[PD Disaggregation][XPU] update_inputs_v1 operator supports PD (#5550)

Co-authored-by: ddchenhao66 <dhaochen163.com>
This commit is contained in:
ddchenhao66
2025-12-15 15:39:38 +08:00
committed by GitHub
parent 97e340eb14
commit 9f70f4310e
2 changed files with 113 additions and 88 deletions

View File

@@ -10,24 +10,25 @@
namespace xpu3 {
namespace plugin {
__global__ void update_inputs_v1(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
__global__ void update_inputs_v1(bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* prompt_lens,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size) {
const int block_size,
bool prefill_one_step_stop) {
// std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl;
int cid = core_id();
int ncores = core_num();
@@ -68,39 +69,55 @@ __global__ void update_inputs_v1(bool *not_need_stop,
seq_len_this_time_update + seq_len_decoder_update;
int prompt_lens_update = 0;
GM2LM(prompt_lens + i, &prompt_lens_update, sizeof(int64_t));
// decoding
if (sum_of_seq_lens_this_time_and_seq_lens_decoder >=
prompt_lens_update) {
seq_len_decoder_update =
seq_len_this_time_update + seq_len_decoder_update;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
seq_len_this_time_update = 1;
LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
seq_lens_encoder_update = 0;
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
int64_t input_ids_update;
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
LM2GM(&input_ids_update,
input_ids + i * input_ids_stride,
sizeof(int64_t));
// to judge whether block is not enough
if (seq_len_this_time_update != 0 &&
block_tables[i * block_num_per_seq +
seq_len_decoder_update / block_size] == -1) {
is_block_step[i] = true;
seq_len_this_time_update = 0;
LM2GM(
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
if (prefill_one_step_stop) {
// prefill done, stop
stop_flags_sm[i] = true;
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
LM2GM(&seq_len_decoder_update,
step_seq_lens_decoder + i,
sizeof(int));
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
seq_len_this_time_update = 0;
seq_len_decoder_update = 0;
seq_lens_encoder_update = 0;
LM2GM(
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
stop_flags_int_sm[i] = 1;
} else {
// decoding
seq_len_decoder_update =
seq_len_this_time_update + seq_len_decoder_update;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
seq_len_this_time_update = 1;
LM2GM(
&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int));
seq_lens_encoder_update = 0;
LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int));
int64_t input_ids_update;
GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t));
LM2GM(&input_ids_update,
input_ids + i * input_ids_stride,
sizeof(int64_t));
// to judge whether block is not enough
if (seq_len_this_time_update != 0 &&
block_tables[i * block_num_per_seq +
seq_len_decoder_update / block_size] == -1) {
is_block_step[i] = true;
seq_len_this_time_update = 0;
LM2GM(&seq_len_this_time_update,
seq_lens_this_time + i,
sizeof(int));
stop_flags_sm[i] = true;
SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool));
LM2GM(&seq_len_decoder_update,
step_seq_lens_decoder + i,
sizeof(int));
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
seq_len_decoder_update = 0;
LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int));
stop_flags_int_sm[i] = 1;
}
}
} else {
stop_flags_sm[i] = true;

View File

@@ -20,24 +20,25 @@
namespace xpu3 {
namespace plugin {
__attribute__((global)) void update_inputs_v1(bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
__attribute__((global)) void update_inputs_v1(bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* prompt_lens,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
const int block_num_per_seq,
const int block_size);
const int block_size,
bool prefill_one_step_stop);
} // namespace plugin
} // namespace xpu3
@@ -47,20 +48,20 @@ namespace xpu {
namespace api {
namespace plugin {
static int xpu3_wrapper(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
static int xpu3_wrapper(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* prompt_lens,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,
@@ -68,6 +69,12 @@ static int xpu3_wrapper(Context *ctx,
const int block_size) {
using XPU_INT64 = typename XPUIndexType<int64_t>::type;
auto update_inputs_v1 = xpu3::plugin::update_inputs_v1;
bool prefill_one_step_stop = false;
if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) {
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}
// kernel 内要做 reduce只能用 1 个 cluster
update_inputs_v1<<<1, 64, ctx->xpu_stream>>>(
not_need_stop,
@@ -75,36 +82,37 @@ static int xpu3_wrapper(Context *ctx,
seq_lens_encoder,
seq_lens_decoder,
step_seq_lens_decoder,
reinterpret_cast<XPU_INT64 *>(prompt_lens),
reinterpret_cast<XPU_INT64 *>(topk_ids),
reinterpret_cast<XPU_INT64 *>(input_ids),
reinterpret_cast<XPU_INT64*>(prompt_lens),
reinterpret_cast<XPU_INT64*>(topk_ids),
reinterpret_cast<XPU_INT64*>(input_ids),
block_tables,
reinterpret_cast<const XPU_INT64 *>(stop_nums),
reinterpret_cast<const XPU_INT64*>(stop_nums),
stop_flags,
is_block_step,
reinterpret_cast<const XPU_INT64 *>(next_tokens),
reinterpret_cast<const XPU_INT64*>(next_tokens),
bsz,
max_bsz,
input_ids_stride,
block_num_per_seq,
block_size);
block_size,
prefill_one_step_stop);
return api::SUCCESS;
}
int update_inputs_v1(Context *ctx,
bool *not_need_stop,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *prompt_lens,
int64_t *topk_ids,
int64_t *input_ids,
int *block_tables,
const int64_t *stop_nums,
bool *stop_flags,
bool *is_block_step,
const int64_t *next_tokens,
int update_inputs_v1(Context* ctx,
bool* not_need_stop,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int* step_seq_lens_decoder,
int64_t* prompt_lens,
int64_t* topk_ids,
int64_t* input_ids,
int* block_tables,
const int64_t* stop_nums,
bool* stop_flags,
bool* is_block_step,
const int64_t* next_tokens,
const int bsz,
const int max_bsz,
const int input_ids_stride,