diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu index e38b47bf3..82f01fabd 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/update_inputs_v1.xpu @@ -10,24 +10,25 @@ namespace xpu3 { namespace plugin { -__global__ void update_inputs_v1(bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, +__global__ void update_inputs_v1(bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int64_t* prompt_lens, + int64_t* topk_ids, + int64_t* input_ids, + int* block_tables, + const int64_t* stop_nums, + bool* stop_flags, + bool* is_block_step, + const int64_t* next_tokens, const int bsz, const int max_bsz, const int input_ids_stride, const int block_num_per_seq, - const int block_size) { + const int block_size, + bool prefill_one_step_stop) { // std::cout << "seq_lens_this_time " << seq_lens_this_time[0] << std::endl; int cid = core_id(); int ncores = core_num(); @@ -68,39 +69,55 @@ __global__ void update_inputs_v1(bool *not_need_stop, seq_len_this_time_update + seq_len_decoder_update; int prompt_lens_update = 0; GM2LM(prompt_lens + i, &prompt_lens_update, sizeof(int64_t)); - // decoding if (sum_of_seq_lens_this_time_and_seq_lens_decoder >= prompt_lens_update) { - seq_len_decoder_update = - seq_len_this_time_update + seq_len_decoder_update; - LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); - seq_len_this_time_update = 1; - LM2GM(&seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); - seq_lens_encoder_update = 0; - LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); - int64_t input_ids_update; - GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t)); - LM2GM(&input_ids_update, - input_ids + i * input_ids_stride, - sizeof(int64_t)); - // to judge whether block is not enough - if (seq_len_this_time_update != 0 && - block_tables[i * block_num_per_seq + - seq_len_decoder_update / block_size] == -1) { - is_block_step[i] = true; - seq_len_this_time_update = 0; - LM2GM( - &seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + if (prefill_one_step_stop) { + // prefill done, stop stop_flags_sm[i] = true; SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool)); - LM2GM(&seq_len_decoder_update, - step_seq_lens_decoder + i, - sizeof(int)); - seq_len_decoder_update = 0; - LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + seq_len_this_time_update = 0; seq_len_decoder_update = 0; + seq_lens_encoder_update = 0; + LM2GM( + &seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); stop_flags_int_sm[i] = 1; + } else { + // decoding + seq_len_decoder_update = + seq_len_this_time_update + seq_len_decoder_update; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + seq_len_this_time_update = 1; + LM2GM( + &seq_len_this_time_update, seq_lens_this_time + i, sizeof(int)); + seq_lens_encoder_update = 0; + LM2GM(&seq_lens_encoder_update, seq_lens_encoder + i, sizeof(int)); + int64_t input_ids_update; + GM2LM(next_tokens + i, &input_ids_update, sizeof(int64_t)); + LM2GM(&input_ids_update, + input_ids + i * input_ids_stride, + sizeof(int64_t)); + // to judge whether block is not enough + if (seq_len_this_time_update != 0 && + block_tables[i * block_num_per_seq + + seq_len_decoder_update / block_size] == -1) { + is_block_step[i] = true; + seq_len_this_time_update = 0; + LM2GM(&seq_len_this_time_update, + seq_lens_this_time + i, + sizeof(int)); + stop_flags_sm[i] = true; + SM2GM(stop_flags_sm + i, stop_flags + i, sizeof(bool)); + LM2GM(&seq_len_decoder_update, + step_seq_lens_decoder + i, + sizeof(int)); + seq_len_decoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + seq_len_decoder_update = 0; + LM2GM(&seq_len_decoder_update, seq_lens_decoder + i, sizeof(int)); + stop_flags_int_sm[i] = 1; + } } } else { stop_flags_sm[i] = true; diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp index 7fe1772c4..dd937686b 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/update_inputs_v1.cpp @@ -20,24 +20,25 @@ namespace xpu3 { namespace plugin { -__attribute__((global)) void update_inputs_v1(bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, +__attribute__((global)) void update_inputs_v1(bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int64_t* prompt_lens, + int64_t* topk_ids, + int64_t* input_ids, + int* block_tables, + const int64_t* stop_nums, + bool* stop_flags, + bool* is_block_step, + const int64_t* next_tokens, const int bsz, const int max_bsz, const int input_ids_stride, const int block_num_per_seq, - const int block_size); + const int block_size, + bool prefill_one_step_stop); } // namespace plugin } // namespace xpu3 @@ -47,20 +48,20 @@ namespace xpu { namespace api { namespace plugin { -static int xpu3_wrapper(Context *ctx, - bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, +static int xpu3_wrapper(Context* ctx, + bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int64_t* prompt_lens, + int64_t* topk_ids, + int64_t* input_ids, + int* block_tables, + const int64_t* stop_nums, + bool* stop_flags, + bool* is_block_step, + const int64_t* next_tokens, const int bsz, const int max_bsz, const int input_ids_stride, @@ -68,6 +69,12 @@ static int xpu3_wrapper(Context *ctx, const int block_size) { using XPU_INT64 = typename XPUIndexType::type; auto update_inputs_v1 = xpu3::plugin::update_inputs_v1; + bool prefill_one_step_stop = false; + if (const char* env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP_V1")) { + if (env_p[0] == '1') { + prefill_one_step_stop = true; + } + } // kernel 内要做 reduce,只能用 1 个 cluster update_inputs_v1<<<1, 64, ctx->xpu_stream>>>( not_need_stop, @@ -75,36 +82,37 @@ static int xpu3_wrapper(Context *ctx, seq_lens_encoder, seq_lens_decoder, step_seq_lens_decoder, - reinterpret_cast(prompt_lens), - reinterpret_cast(topk_ids), - reinterpret_cast(input_ids), + reinterpret_cast(prompt_lens), + reinterpret_cast(topk_ids), + reinterpret_cast(input_ids), block_tables, - reinterpret_cast(stop_nums), + reinterpret_cast(stop_nums), stop_flags, is_block_step, - reinterpret_cast(next_tokens), + reinterpret_cast(next_tokens), bsz, max_bsz, input_ids_stride, block_num_per_seq, - block_size); + block_size, + prefill_one_step_stop); return api::SUCCESS; } -int update_inputs_v1(Context *ctx, - bool *not_need_stop, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int64_t *prompt_lens, - int64_t *topk_ids, - int64_t *input_ids, - int *block_tables, - const int64_t *stop_nums, - bool *stop_flags, - bool *is_block_step, - const int64_t *next_tokens, +int update_inputs_v1(Context* ctx, + bool* not_need_stop, + int* seq_lens_this_time, + int* seq_lens_encoder, + int* seq_lens_decoder, + int* step_seq_lens_decoder, + int64_t* prompt_lens, + int64_t* topk_ids, + int64_t* input_ids, + int* block_tables, + const int64_t* stop_nums, + bool* stop_flags, + bool* is_block_step, + const int64_t* next_tokens, const int bsz, const int max_bsz, const int input_ids_stride,