mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[XPU] Remove padding_offsets from get_padding_offset.cu (#2911)
This commit is contained in:
@@ -34,7 +34,7 @@ __global__ void RemovePadding(int64_t *output_data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void GetPaddingOffsetKernel(int *padding_offset,
|
__global__ void GetPaddingOffsetKernel(int *batch_id_per_token,
|
||||||
int *cum_offsets_out,
|
int *cum_offsets_out,
|
||||||
int *cu_seqlens_q,
|
int *cu_seqlens_q,
|
||||||
int *cu_seqlens_k,
|
int *cu_seqlens_k,
|
||||||
@@ -46,7 +46,7 @@ __global__ void GetPaddingOffsetKernel(int *padding_offset,
|
|||||||
const int ti = threadIdx.x;
|
const int ti = threadIdx.x;
|
||||||
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
||||||
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
|
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
|
||||||
padding_offset[bi * max_seq_len - cum_offset + i] = bi;
|
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
|
||||||
}
|
}
|
||||||
if (ti == 0) {
|
if (ti == 0) {
|
||||||
cum_offsets_out[bi] = cum_offset;
|
cum_offsets_out[bi] = cum_offset;
|
||||||
@@ -75,7 +75,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||||
auto x_remove_padding = paddle::empty(
|
auto x_remove_padding = paddle::empty(
|
||||||
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
{token_num_data}, paddle::DataType::INT64, input_ids.place());
|
||||||
auto padding_offset = paddle::empty(
|
auto batch_id_per_token = paddle::empty(
|
||||||
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
{token_num_data}, paddle::DataType::INT32, input_ids.place());
|
||||||
auto cu_seqlens_q =
|
auto cu_seqlens_q =
|
||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
@@ -87,7 +87,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
|
int blockSize = min((token_num_data + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE, 128);
|
||||||
#endif
|
#endif
|
||||||
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
GetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
||||||
padding_offset.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cum_offsets_out.data<int>(),
|
cum_offsets_out.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
cu_seqlens_k.data<int>(),
|
cu_seqlens_k.data<int>(),
|
||||||
@@ -102,7 +102,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
seq_length);
|
seq_length);
|
||||||
return {x_remove_padding,
|
return {x_remove_padding,
|
||||||
cum_offsets_out,
|
cum_offsets_out,
|
||||||
padding_offset,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k}; // , enc_token_num, dec_token_num};
|
cu_seqlens_k}; // , enc_token_num, dec_token_num};
|
||||||
}
|
}
|
||||||
@@ -133,7 +133,7 @@ PD_BUILD_STATIC_OP(get_padding_offset)
|
|||||||
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
|
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
|
||||||
.Outputs({"x_remove_padding",
|
.Outputs({"x_remove_padding",
|
||||||
"cum_offsets_out",
|
"cum_offsets_out",
|
||||||
"padding_offset",
|
"batch_id_per_token",
|
||||||
"cu_seqlens_q",
|
"cu_seqlens_q",
|
||||||
"cu_seqlens_k"})
|
"cu_seqlens_k"})
|
||||||
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
||||||
|
@@ -41,7 +41,7 @@ __global__ void SpeculateRemovePadding(int64_t* output_data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
|
__global__ void SpeculateGetPaddingOffsetKernel(int* batch_id_per_token,
|
||||||
int* cum_offsets_out,
|
int* cum_offsets_out,
|
||||||
int* cu_seqlens_q,
|
int* cu_seqlens_q,
|
||||||
int* cu_seqlens_k,
|
int* cu_seqlens_k,
|
||||||
@@ -53,7 +53,7 @@ __global__ void SpeculateGetPaddingOffsetKernel(int* padding_offset,
|
|||||||
const int ti = threadIdx.x;
|
const int ti = threadIdx.x;
|
||||||
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
|
||||||
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
|
for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
|
||||||
padding_offset[bi * max_seq_len - cum_offset + i] = bi;
|
batch_id_per_token[bi * max_seq_len - cum_offset + i] = bi;
|
||||||
}
|
}
|
||||||
if (ti == 0) {
|
if (ti == 0) {
|
||||||
cum_offsets_out[bi] = cum_offset;
|
cum_offsets_out[bi] = cum_offset;
|
||||||
@@ -81,7 +81,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
|||||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||||
auto x_remove_padding = paddle::full(
|
auto x_remove_padding = paddle::full(
|
||||||
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
||||||
auto padding_offset = paddle::full(
|
auto batch_id_per_token = paddle::full(
|
||||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
auto cu_seqlens_q =
|
auto cu_seqlens_q =
|
||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
@@ -89,7 +89,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
|||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
|
int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
|
||||||
SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
SpeculateGetPaddingOffsetKernel<<<bsz, 128, 0, cu_stream>>>(
|
||||||
padding_offset.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cum_offsets_out.data<int>(),
|
cum_offsets_out.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
cu_seqlens_k.data<int>(),
|
cu_seqlens_k.data<int>(),
|
||||||
@@ -107,7 +107,7 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
|
|||||||
max_draft_tokens);
|
max_draft_tokens);
|
||||||
return {x_remove_padding,
|
return {x_remove_padding,
|
||||||
cum_offsets_out,
|
cum_offsets_out,
|
||||||
padding_offset,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k}; // , enc_token_num, dec_token_num};
|
cu_seqlens_k}; // , enc_token_num, dec_token_num};
|
||||||
}
|
}
|
||||||
@@ -147,7 +147,7 @@ PD_BUILD_STATIC_OP(speculate_get_padding_offset)
|
|||||||
"seq_lens_encoder"})
|
"seq_lens_encoder"})
|
||||||
.Outputs({"x_remove_padding",
|
.Outputs({"x_remove_padding",
|
||||||
"cum_offsets_out",
|
"cum_offsets_out",
|
||||||
"padding_offset",
|
"batch_id_per_token",
|
||||||
"cu_seqlens_q",
|
"cu_seqlens_q",
|
||||||
"cu_seqlens_k"})
|
"cu_seqlens_k"})
|
||||||
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
|
.SetKernelFn(PD_KERNEL(SpeculateGetPaddingOffset))
|
||||||
|
@@ -34,7 +34,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
const int token_num_data = cpu_token_num.data<int64_t>()[0];
|
||||||
auto x_remove_padding = paddle::full(
|
auto x_remove_padding = paddle::full(
|
||||||
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
{token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
|
||||||
auto padding_offset = paddle::full(
|
auto batch_id_per_token = paddle::full(
|
||||||
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
{token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
auto cu_seqlens_q =
|
auto cu_seqlens_q =
|
||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
@@ -42,7 +42,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
|
||||||
int r = baidu::xpu::api::plugin::get_padding_offset(
|
int r = baidu::xpu::api::plugin::get_padding_offset(
|
||||||
xpu_ctx->x_context(),
|
xpu_ctx->x_context(),
|
||||||
padding_offset.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cum_offsets_out.data<int>(),
|
cum_offsets_out.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
cu_seqlens_k.data<int>(),
|
cu_seqlens_k.data<int>(),
|
||||||
@@ -55,7 +55,7 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
|
|||||||
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
|
PD_CHECK(r == 0, "baidu::xpu::api::plugin::get_padding_offset failed.");
|
||||||
return {x_remove_padding,
|
return {x_remove_padding,
|
||||||
cum_offsets_out,
|
cum_offsets_out,
|
||||||
padding_offset,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k};
|
cu_seqlens_k};
|
||||||
}
|
}
|
||||||
@@ -86,7 +86,7 @@ PD_BUILD_OP(get_padding_offset)
|
|||||||
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
|
||||||
.Outputs({"x_remove_padding",
|
.Outputs({"x_remove_padding",
|
||||||
"cum_offsets_out",
|
"cum_offsets_out",
|
||||||
"padding_offset",
|
"batch_id_per_token",
|
||||||
"cu_seqlens_q",
|
"cu_seqlens_q",
|
||||||
"cu_seqlens_k"})
|
"cu_seqlens_k"})
|
||||||
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
.SetKernelFn(PD_KERNEL(GetPaddingOffset))
|
||||||
|
@@ -5,7 +5,7 @@
|
|||||||
namespace xpu3 {
|
namespace xpu3 {
|
||||||
namespace plugin {
|
namespace plugin {
|
||||||
|
|
||||||
__global__ void get_padding_offset(int *padding_offset,
|
__global__ void get_padding_offset(int *batch_id_per_token,
|
||||||
int *cum_offsets_out,
|
int *cum_offsets_out,
|
||||||
int *cu_seqlens_q,
|
int *cu_seqlens_q,
|
||||||
int *cu_seqlens_k,
|
int *cu_seqlens_k,
|
||||||
@@ -20,7 +20,7 @@ __global__ void get_padding_offset(int *padding_offset,
|
|||||||
int tid = clusterid * ncores + cid;
|
int tid = clusterid * ncores + cid;
|
||||||
|
|
||||||
int buf_len = 32;
|
int buf_len = 32;
|
||||||
__simd__ int padding_offset_lm[buf_len];
|
__simd__ int batch_id_per_token_lm[buf_len];
|
||||||
__simd__ int cum_offsets_lm[16];
|
__simd__ int cum_offsets_lm[16];
|
||||||
int seq_len_lm;
|
int seq_len_lm;
|
||||||
for (int i = clusterid; i < bs; i += nclusters) {
|
for (int i = clusterid; i < bs; i += nclusters) {
|
||||||
@@ -32,11 +32,11 @@ __global__ void get_padding_offset(int *padding_offset,
|
|||||||
for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) {
|
for (int j = cid * buf_len; j < seq_len_lm; j += ncores * buf_len) {
|
||||||
int cur_len = min(seq_len_lm - j, buf_len);
|
int cur_len = min(seq_len_lm - j, buf_len);
|
||||||
for (int k = 0; k < cur_len; k++) {
|
for (int k = 0; k < cur_len; k++) {
|
||||||
padding_offset_lm[k] = cum_offsets_lm[0];
|
batch_id_per_token_lm[k] = i;
|
||||||
}
|
}
|
||||||
mfence_lm();
|
mfence_lm();
|
||||||
LM2GM(padding_offset_lm,
|
LM2GM(batch_id_per_token_lm,
|
||||||
padding_offset + i * max_seq_len - cum_offsets_lm[0] + j,
|
batch_id_per_token + i * max_seq_len - cum_offsets_lm[0] + j,
|
||||||
cur_len * sizeof(int));
|
cur_len * sizeof(int));
|
||||||
}
|
}
|
||||||
if (cid == 0) {
|
if (cid == 0) {
|
||||||
|
@@ -58,7 +58,7 @@ def xpu_pre_process(
|
|||||||
(
|
(
|
||||||
ids_remove_padding,
|
ids_remove_padding,
|
||||||
cum_offsets,
|
cum_offsets,
|
||||||
padding_offset,
|
batch_id_per_token,
|
||||||
cu_seqlens_q,
|
cu_seqlens_q,
|
||||||
cu_seqlens_k,
|
cu_seqlens_k,
|
||||||
) = get_padding_offset(input_ids, cum_offsets_now, token_num,
|
) = get_padding_offset(input_ids, cum_offsets_now, token_num,
|
||||||
@@ -66,7 +66,7 @@ def xpu_pre_process(
|
|||||||
|
|
||||||
share_inputs["ids_remove_padding"] = None # set this after adjust batch
|
share_inputs["ids_remove_padding"] = None # set this after adjust batch
|
||||||
share_inputs["cum_offsets"] = cum_offsets
|
share_inputs["cum_offsets"] = cum_offsets
|
||||||
share_inputs["padding_offset"] = padding_offset
|
share_inputs["batch_id_per_token"] = batch_id_per_token
|
||||||
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
share_inputs["cu_seqlens_q"] = cu_seqlens_q
|
||||||
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
share_inputs["cu_seqlens_k"] = cu_seqlens_k
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ def xpu_pre_process(
|
|||||||
seq_lens_decoder=share_inputs["seq_lens_decoder"],
|
seq_lens_decoder=share_inputs["seq_lens_decoder"],
|
||||||
seq_lens_this_time=share_inputs["seq_lens_this_time"],
|
seq_lens_this_time=share_inputs["seq_lens_this_time"],
|
||||||
cum_offsets=share_inputs["cum_offsets"],
|
cum_offsets=share_inputs["cum_offsets"],
|
||||||
padding_offset=share_inputs["padding_offset"],
|
batch_id_per_token=share_inputs["batch_id_per_token"],
|
||||||
cu_seqlens_q=share_inputs["cu_seqlens_q"],
|
cu_seqlens_q=share_inputs["cu_seqlens_q"],
|
||||||
cu_seqlens_k=share_inputs["cu_seqlens_k"],
|
cu_seqlens_k=share_inputs["cu_seqlens_k"],
|
||||||
block_tables=share_inputs["block_tables"],
|
block_tables=share_inputs["block_tables"],
|
||||||
|
Reference in New Issue
Block a user