diff --git a/custom_ops/gpu_ops/get_padding_offset.cu b/custom_ops/gpu_ops/get_padding_offset.cu
index 92c252f17..8fae9b88c 100644
--- a/custom_ops/gpu_ops/get_padding_offset.cu
+++ b/custom_ops/gpu_ops/get_padding_offset.cu
@@ -34,7 +34,7 @@ __global__ void RemovePadding(int64_t *output_data,
   }
 }
 
-__global__ void GetPaddingOffsetKernel(int *padding_offset,
+__global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
                                        int *cum_offsets_out,
                                        int *cu_seqlens_q,
                                        int *cu_seqlens_k,
@@ -46,7 +46,7 @@ __global__ void GetPaddingOffsetKernel(int *padding_offset,
   const int ti = threadIdx.x;
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-    padding_offset[bi * max_seq_len - cum_offset + i] = bi;
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
   }
   if (ti == 0) {
     cum_offsets_out[bi] = cum_offset;
@@ -75,7 +75,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
   auto x_remove_padding = paddle::empty(
       {token_num_data}, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::empty(
+  auto batch_id_per_token = paddle::empty(
       {token_num_data}, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_q =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
@@ -87,7 +87,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
 #endif
   GetPaddingOffsetKernel<<<bsz, blockSize, 0, cu_stream>>>(
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
       cum_offsets_out.data<int>(),
       cu_seqlens_q.data<int>(),
       cu_seqlens_k.data<int>(),
@@ -102,7 +102,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
       seq_length);
   return {x_remove_padding,
           cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
@@ -133,7 +133,7 @@ PD_BUILD_STATIC_OP(get_padding_offset)
     .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
     .Outputs({"x_remove_padding",
               "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffset))
diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
index 2fbfff160..96186d761 100644
--- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
+++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu
@@ -41,7 +41,7 @@ __global__ void SpeculateRemovePadding(int64_t* output_data,
   }
 }
 
-__global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
+__global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
                                                 int* cum_offsets_out,
                                                 int* cu_seqlens_q,
                                                 int* cu_seqlens_k,
@@ -53,7 +53,7 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
   const int ti = threadIdx.x;
   int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
   for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
-    padding_offset[bi * max_seq_len - cum_offset + i] = bi;
+    batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
   }
   if (ti == 0) {
     cum_offsets_out[bi] = cum_offset;
@@ -81,7 +81,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
   auto x_remove_padding = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::full(
+  auto batch_id_per_token = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_q =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
@@ -89,7 +89,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
   SpeculateGetPaddingOffsetKernel<<<bsz, blockSize, 0, cu_stream>>>(
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
       cum_offsets_out.data<int>(),
       cu_seqlens_q.data<int>(),
       cu_seqlens_k.data<int>(),
@@ -107,7 +107,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
       max_draft_tokens);
   return {x_remove_padding,
           cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};  // , enc_token_num, dec_token_num};
 }
@@ -147,7 +147,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset)
               "seq_lens_encoder"})
     .Outputs({"x_remove_padding",
               "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
diff --git a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc
index 203a8055d..e83cecb19 100644
--- a/custom_ops/xpu_ops/src/ops/get_padding_offset.cc
+++ b/custom_ops/xpu_ops/src/ops/get_padding_offset.cc
@@ -34,7 +34,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   const int token_num_data = cpu_token_num.data<int64_t>()[0];
   auto x_remove_padding = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
-  auto padding_offset = paddle::full(
+  auto batch_id_per_token = paddle::full(
       {token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
   auto cu_seqlens_q =
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
@@ -42,7 +42,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
       paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
   int r = baidu::xpu::api::plugin::get_padding_offset(
       xpu_ctx->x_context(),
-      padding_offset.data<int>(),
+      batch_id_per_token.data<int>(),
       cum_offsets_out.data<int>(),
       cu_seqlens_q.data<int>(),
       cu_seqlens_k.data<int>(),
@@ -55,7 +55,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
   PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
   return {x_remove_padding,
           cum_offsets_out,
-          padding_offset,
+          batch_id_per_token,
           cu_seqlens_q,
           cu_seqlens_k};
 }
@@ -86,7 +86,7 @@ PD_BUILD_OP(get_padding_offset)
     .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
     .Outputs({"x_remove_padding",
               "cum_offsets_out",
-              "padding_offset",
+              "batch_id_per_token",
               "cu_seqlens_q",
               "cu_seqlens_k"})
     .SetKernelFn(PD_KERNEL(GetPaddingOffset))
diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu
index b5df4d743..5416b0045 100644
--- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/get_padding_offset.xpu
@@ -5,7 +5,7 @@
 namespace xpu3 {
 namespace plugin {
 
-__global__ void get_padding_offset(int *padding_offset,
+__global__ void get_padding_offset(int *batch_id_per_token,
                                    int *cum_offsets_out,
                                    int *cu_seqlens_q,
                                    int *cu_seqlens_k,
@@ -20,7 +20,7 @@ __global__ void get_padding_offset(int *padding_offset,
   int tid = clusterid * ncores + cid;
 
   int buf_len = 32;
-  __simd__ int padding_offset_lm[buf_len];
+  __simd__ int batch_id_per_token_lm[buf_len];
   __simd__ int cum_offsets_lm[16];
   int seq_len_lm;
   for (int i = clusterid; i < bs; i += nclusters) {
@@ -32,11 +32,11 @@ __global__ void get_padding_offset(int *padding_offset,
     for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) {
       int cur_len = min(seq_len_lm - j, buf_len);
       for (int k = 0; k < cur_len; k++) {
-        padding_offset_lm[k] = cum_offsets_lm[0];
+        batch_id_per_token_lm[k] = i;
       }
       mfence_lm();
-      LM2GM(padding_offset_lm,
-            padding_offset + i * max_seq_len - cum_offsets_lm[0] + j,
+      LM2GM(batch_id_per_token_lm,
+            batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
             cur_len * sizeof(int));
     }
     if (cid == 0) {
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index 0d3329c1d..50b1a3b9d 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -58,7 +58,7 @@ def xpu_pre_process(
     (
         ids_remove_padding,
         cum_offsets,
-        padding_offset,
+        batch_id_per_token,
         cu_seqlens_q,
         cu_seqlens_k,
     ) = get_padding_offset(input_ids, cum_offsets_now, token_num,
@@ -66,7 +66,7 @@
 
     share_inputs["ids_remove_padding"] = None  # set this after adjust batch
     share_inputs["cum_offsets"] = cum_offsets
-    share_inputs["padding_offset"] = padding_offset
+    share_inputs["batch_id_per_token"] = batch_id_per_token
     share_inputs["cu_seqlens_q"] = cu_seqlens_q
     share_inputs["cu_seqlens_k"] = cu_seqlens_k
@@ -79,7 +79,7 @@ def xpu_pre_process(
         seq_lens_decoder=share_inputs["seq_lens_decoder"],
         seq_lens_this_time=share_inputs["seq_lens_this_time"],
         cum_offsets=share_inputs["cum_offsets"],
-        padding_offset=share_inputs["padding_offset"],
+        batch_id_per_token=share_inputs["batch_id_per_token"],
        cu_seqlens_q=share_inputs["cu_seqlens_q"],
        cu_seqlens_k=share_inputs["cu_seqlens_k"],
        block_tables=share_inputs["block_tables"],
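
Note on semantics: on the CUDA path this is a pure rename (the kernels already wrote the batch index `bi` for every packed token), but the Kunlun kernel previously stored `cum_offsets_lm[0]`, i.e. a padding offset, so this patch also changes the value it writes to the batch index `i`, aligning the XPU output with the GPU one. Below is a minimal CPU sketch of what the renamed kernels compute; the function name and standalone form are illustrative, not part of the patch.

def batch_id_per_token_ref(seq_lens, max_seq_len):
    """CPU reference for the batch_id_per_token output of the kernels above.

    Mirrors the kernels' index math: token i of batch bi lands at packed
    position bi * max_seq_len - cum_offset + i, where cum_offset is the
    padding accumulated over earlier batches (cum_offsets[bi - 1]).
    """
    out = [0] * sum(seq_lens)
    cum_offset = 0
    for bi, n in enumerate(seq_lens):
        for i in range(n):
            out[bi * max_seq_len - cum_offset + i] = bi
        cum_offset += max_seq_len - n  # padding contributed by batch bi
    return out

# Three sequences of lengths 3, 1, 2 packed into a 6-token tensor:
assert batch_id_per_token_ref([3, 1, 2], max_seq_len=4) == [0, 0, 0, 1, 1, 2]

After padding removal, each position of the packed token tensor therefore holds the id of the sequence it came from, which is how consumers of share_inputs["batch_id_per_token"] in the model runner use it.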