mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[PD Disaggregation][XPU] update_inputs_v1 operator supports PD (#5550)
Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
@@ -10,24 +10,25 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__global__ void update_inputs_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
__global__ void update_inputs_v1(bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* prompt_lens,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size) {
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop) {
|
||||
// std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl;
|
||||
int cid = core_id();
|
||||
int ncores = core_num();
|
||||
@@ -68,39 +69,55 @@ __global__ void update_inputs_v1(bool *not_need_stop,
|
||||
seq_len_this_time_update + seq_len_decoder_update;
|
||||
int prompt_lens_update = 0;
|
||||
GM2LM(prompt_lens + i, &prompt_lens_update, sizeof(int64_t));
|
||||
// decoding
|
||||
if (sum_of_seq_lens_this_time_and_seq_lens_decoder >=
|
||||
prompt_lens_update) {
|
||||
seq_len_decoder_update =
|
||||
seq_len_this_time_update + seq_len_decoder_update;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
seq_len_this_time_update = 1;
|
||||
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
int64_t input_ids_update;
|
||||
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
|
||||
LM2GM(&input_ids_update,
|
||||
input_ids + i * input_ids_stride,
|
||||
sizeof(int64_t));
|
||||
// to judge whether block is not enough
|
||||
if (seq_len_this_time_update != 0 &&
|
||||
block_tables[i * block_num_per_seq +
|
||||
seq_len_decoder_update / block_size] == -1) {
|
||||
is_block_step[i] = true;
|
||||
seq_len_this_time_update = 0;
|
||||
LM2GM(
|
||||
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
if (prefill_one_step_stop) {
|
||||
// prefill done, stop
|
||||
stop_flags_sm[i] = true;
|
||||
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
|
||||
LM2GM(&seq_len_decoder_update,
|
||||
step_seq_lens_decoder + i,
|
||||
sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
seq_len_this_time_update = 0;
|
||||
seq_len_decoder_update = 0;
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(
|
||||
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
stop_flags_int_sm[i] = 1;
|
||||
} else {
|
||||
// decoding
|
||||
seq_len_decoder_update =
|
||||
seq_len_this_time_update + seq_len_decoder_update;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
seq_len_this_time_update = 1;
|
||||
LM2GM(
|
||||
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
|
||||
seq_lens_encoder_update = 0;
|
||||
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
|
||||
int64_t input_ids_update;
|
||||
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
|
||||
LM2GM(&input_ids_update,
|
||||
input_ids + i * input_ids_stride,
|
||||
sizeof(int64_t));
|
||||
// to judge whether block is not enough
|
||||
if (seq_len_this_time_update != 0 &&
|
||||
block_tables[i * block_num_per_seq +
|
||||
seq_len_decoder_update / block_size] == -1) {
|
||||
is_block_step[i] = true;
|
||||
seq_len_this_time_update = 0;
|
||||
LM2GM(&seq_len_this_time_update,
|
||||
seq_lens_this_time + i,
|
||||
sizeof(int));
|
||||
stop_flags_sm[i] = true;
|
||||
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
|
||||
LM2GM(&seq_len_decoder_update,
|
||||
step_seq_lens_decoder + i,
|
||||
sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
seq_len_decoder_update = 0;
|
||||
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
|
||||
stop_flags_int_sm[i] = 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stop_flags_sm[i] = true;
|
||||
|
||||
@@ -20,24 +20,25 @@
|
||||
namespace xpu3 {
|
||||
namespace plugin {
|
||||
|
||||
__attribute__((global)) void update_inputs_v1(bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
__attribute__((global)) void update_inputs_v1(bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* prompt_lens,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
const int block_num_per_seq,
|
||||
const int block_size);
|
||||
const int block_size,
|
||||
bool prefill_one_step_stop);
|
||||
|
||||
} // namespace plugin
|
||||
} // namespace xpu3
|
||||
@@ -47,20 +48,20 @@ namespace xpu {
|
||||
namespace api {
|
||||
namespace plugin {
|
||||
|
||||
static int xpu3_wrapper(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
static int xpu3_wrapper(Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* prompt_lens,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
@@ -68,6 +69,12 @@ static int xpu3_wrapper(Context *ctx,
|
||||
const int block_size) {
|
||||
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
|
||||
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
|
||||
bool prefill_one_step_stop = false;
|
||||
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
|
||||
if (env_p[0] == '1') {
|
||||
prefill_one_step_stop = true;
|
||||
}
|
||||
}
|
||||
// The kernel performs a reduce internally, so it can only be launched on 1 cluster
|
||||
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
|
||||
not_need_stop,
|
||||
@@ -75,36 +82,37 @@ static int xpu3_wrapper(Context *ctx,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
step_seq_lens_decoder,
|
||||
reinterpret_cast<XPU_INT64 *>(prompt_lens),
|
||||
reinterpret_cast<XPU_INT64 *>(topk_ids),
|
||||
reinterpret_cast<XPU_INT64 *>(input_ids),
|
||||
reinterpret_cast<XPU_INT64*>(prompt_lens),
|
||||
reinterpret_cast<XPU_INT64*>(topk_ids),
|
||||
reinterpret_cast<XPU_INT64*>(input_ids),
|
||||
block_tables,
|
||||
reinterpret_cast<const XPU_INT64 *>(stop_nums),
|
||||
reinterpret_cast<const XPU_INT64*>(stop_nums),
|
||||
stop_flags,
|
||||
is_block_step,
|
||||
reinterpret_cast<const XPU_INT64 *>(next_tokens),
|
||||
reinterpret_cast<const XPU_INT64*>(next_tokens),
|
||||
bsz,
|
||||
max_bsz,
|
||||
input_ids_stride,
|
||||
block_num_per_seq,
|
||||
block_size);
|
||||
block_size,
|
||||
prefill_one_step_stop);
|
||||
return api::SUCCESS;
|
||||
}
|
||||
|
||||
int update_inputs_v1(Context *ctx,
|
||||
bool *not_need_stop,
|
||||
int *seq_lens_this_time,
|
||||
int *seq_lens_encoder,
|
||||
int *seq_lens_decoder,
|
||||
int *step_seq_lens_decoder,
|
||||
int64_t *prompt_lens,
|
||||
int64_t *topk_ids,
|
||||
int64_t *input_ids,
|
||||
int *block_tables,
|
||||
const int64_t *stop_nums,
|
||||
bool *stop_flags,
|
||||
bool *is_block_step,
|
||||
const int64_t *next_tokens,
|
||||
int update_inputs_v1(Context* ctx,
|
||||
bool* not_need_stop,
|
||||
int* seq_lens_this_time,
|
||||
int* seq_lens_encoder,
|
||||
int* seq_lens_decoder,
|
||||
int* step_seq_lens_decoder,
|
||||
int64_t* prompt_lens,
|
||||
int64_t* topk_ids,
|
||||
int64_t* input_ids,
|
||||
int* block_tables,
|
||||
const int64_t* stop_nums,
|
||||
bool* stop_flags,
|
||||
bool* is_block_step,
|
||||
const int64_t* next_tokens,
|
||||
const int bsz,
|
||||
const int max_bsz,
|
||||
const int input_ids_stride,
|
||||
|
||||
Reference in New Issue
Block a user