[Code Simplification] remove cum_offsets (#3410)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled

This commit is contained in:
lizexu123
2025-08-18 20:21:25 +08:00
committed by GitHub
parent 2cf96ddd68
commit 32b39620bc
9 changed files with 73 additions and 87 deletions

View File

@@ -84,7 +84,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
seq_length, seq_length,
bsz); bsz);
return {x_remove_padding, return {x_remove_padding,
cum_offsets_out,
padding_offset, padding_offset,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k}; cu_seqlens_k};
@@ -97,7 +96,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) { const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0]; int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1]; int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
} }
std::vector<paddle::DataType> GetPaddingOffsetInferDtype( std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -106,7 +105,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype, const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) { const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype, return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype, seq_len_dtype,
seq_len_dtype, seq_len_dtype,
seq_len_dtype}; seq_len_dtype};
@@ -115,7 +113,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset_cpu) PD_BUILD_STATIC_OP(get_padding_offset_cpu)
.Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"}) .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
.Outputs({"x_remove_padding", .Outputs({"x_remove_padding",
"cum_offsets_out",
"padding_offset", "padding_offset",
"cu_seqlens_q", "cu_seqlens_q",
"cu_seqlens_k"}) "cu_seqlens_k"})

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@@ -19,10 +19,11 @@
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif #endif
template <typename T> template <typename T>
void RebuildPaddingCPUImpl(T *output_data, void RebuildPaddingCPUImpl(T *output_data,
const T *input_data, const T *input_data,
const int *cum_offsets_data, const int *cu_seqlens_q_data,
const int *seq_len_this_time_data, const int *seq_len_this_time_data,
const int *seq_lens_decoder_data, const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data, const int *seq_lens_encoder_data,
@@ -40,11 +41,12 @@ void RebuildPaddingCPUImpl(T *output_data,
if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) { if (seq_lens_decoder_data[bi] == 0 && seq_lens_encoder_data[bi] == 0) {
continue; continue;
} }
if (seq_lens_encoder_data[bi] > 0) { if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1; seq_id = seq_lens_encoder_data[bi] - 1;
} }
const int ori_token_idx =
bi * max_input_length - cum_offsets_data[bi] + seq_id; const int ori_token_idx = cu_seqlens_q_data[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx; const int src_offset = ori_token_idx * dim_embed + bias_idx;
output_data[i] = input_data[src_offset]; output_data[i] = input_data[src_offset];
@@ -54,7 +56,7 @@ void RebuildPaddingCPUImpl(T *output_data,
template <typename T> template <typename T>
void RebuildAppendPaddingCPUImpl(T *output_data, void RebuildAppendPaddingCPUImpl(T *output_data,
const T *input_data, const T *input_data,
const int *cum_offsets_data, const int *cu_seqlens_q_data,
const int *seq_len_this_time_data, const int *seq_len_this_time_data,
const int *seq_lens_decoder_data, const int *seq_lens_decoder_data,
const int *seq_lens_encoder_data, const int *seq_lens_encoder_data,
@@ -73,26 +75,28 @@ void RebuildAppendPaddingCPUImpl(T *output_data,
continue; continue;
} }
int seq_id = 0; int seq_id = 0;
if (seq_lens_encoder_data[bi] > 0) { if (seq_lens_encoder_data[bi] > 0) {
seq_id = seq_lens_encoder_data[bi] - 1; seq_id = seq_lens_encoder_data[bi] - 1;
} }
int input_token_id = ori_token_id - cum_offsets_data[bi] + seq_id; int input_token_id = cu_seqlens_q_data[bi] + seq_id;
int bias_idx = i % dim_embed; int bias_idx = i % dim_embed;
int src_offset = input_token_id * dim_embed + bias_idx; int src_offset = input_token_id * dim_embed + bias_idx;
output_data[i] = input_data[src_offset]; output_data[i] = input_data[src_offset];
} }
} }
std::vector<paddle::Tensor> RebuildPaddingCPU( std::vector<paddle::Tensor> RebuildPaddingCPU(
const paddle::Tensor &tmp_out, const paddle::Tensor &tmp_out,
const paddle::Tensor &cum_offsets, const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset, const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) { int max_input_length) {
auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true); auto tmp_out_cpu = tmp_out.copy_to(paddle::CPUPlace(), true);
auto cum_offsets_cpu = cum_offsets.copy_to(paddle::CPUPlace(), true); auto cu_seqlens_q_cpu = cu_seqlens_q.copy_to(paddle::CPUPlace(), true);
auto seq_len_this_time_cpu = auto seq_len_this_time_cpu =
seq_len_this_time.copy_to(paddle::CPUPlace(), true); seq_len_this_time.copy_to(paddle::CPUPlace(), true);
auto seq_lens_decoder_cpu = auto seq_lens_decoder_cpu =
@@ -107,7 +111,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
int token_num = tmp_out_cpu.shape()[0]; int token_num = tmp_out_cpu.shape()[0];
int dim_embed = tmp_out_cpu.shape()[1]; int dim_embed = tmp_out_cpu.shape()[1];
int bsz = cum_offsets_cpu.shape()[0]; int bsz = cu_seqlens_q_cpu.shape()[0] - 1;
paddle::Tensor out; paddle::Tensor out;
if (output_padding_offset_cpu) { if (output_padding_offset_cpu) {
@@ -128,7 +132,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
{bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace()); {bsz, dim_embed}, 0, tmp_out_cpu.dtype(), paddle::CPUPlace());
} }
const int *cum_offsets_data = cum_offsets_cpu.data<int>(); const int *cu_seqlens_q_data = cu_seqlens_q_cpu.data<int>();
const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>(); const int *seq_len_this_time_data = seq_len_this_time_cpu.data<int>();
const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>(); const int *seq_lens_decoder_data = seq_lens_decoder_cpu.data<int>();
const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>(); const int *seq_lens_encoder_data = seq_lens_encoder_cpu.data<int>();
@@ -141,7 +145,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32: case paddle::DataType::FLOAT32:
RebuildAppendPaddingCPUImpl<float>(out.data<float>(), RebuildAppendPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(), tmp_out_cpu.data<float>(),
cum_offsets_data, cu_seqlens_q_data,
seq_len_this_time_data, seq_len_this_time_data,
seq_lens_decoder_data, seq_lens_decoder_data,
seq_lens_encoder_data, seq_lens_encoder_data,
@@ -154,7 +158,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::float16>( RebuildAppendPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(), out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(), tmp_out_cpu.data<paddle::float16>(),
cum_offsets_data, cu_seqlens_q_data,
seq_len_this_time_data, seq_len_this_time_data,
seq_lens_decoder_data, seq_lens_decoder_data,
seq_lens_encoder_data, seq_lens_encoder_data,
@@ -167,7 +171,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildAppendPaddingCPUImpl<paddle::bfloat16>( RebuildAppendPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(), out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(), tmp_out_cpu.data<paddle::bfloat16>(),
cum_offsets_data, cu_seqlens_q_data,
seq_len_this_time_data, seq_len_this_time_data,
seq_lens_decoder_data, seq_lens_decoder_data,
seq_lens_encoder_data, seq_lens_encoder_data,
@@ -186,7 +190,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
case paddle::DataType::FLOAT32: case paddle::DataType::FLOAT32:
RebuildPaddingCPUImpl<float>(out.data<float>(), RebuildPaddingCPUImpl<float>(out.data<float>(),
tmp_out_cpu.data<float>(), tmp_out_cpu.data<float>(),
cum_offsets_data, cu_seqlens_q_data,
seq_len_this_time_data, seq_len_this_time_data,
seq_lens_decoder_data, seq_lens_decoder_data,
seq_lens_encoder_data, seq_lens_encoder_data,
@@ -198,7 +202,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
RebuildPaddingCPUImpl<paddle::float16>( RebuildPaddingCPUImpl<paddle::float16>(
out.data<paddle::float16>(), out.data<paddle::float16>(),
tmp_out_cpu.data<paddle::float16>(), tmp_out_cpu.data<paddle::float16>(),
cum_offsets_data, cu_seqlens_q_data,
seq_len_this_time_data, seq_len_this_time_data,
seq_lens_decoder_data, seq_lens_decoder_data,
seq_lens_encoder_data, seq_lens_encoder_data,
@@ -207,11 +211,10 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
elem_nums); elem_nums);
break; break;
case paddle::DataType::BFLOAT16: case paddle::DataType::BFLOAT16:
RebuildPaddingCPUImpl<paddle::bfloat16>( RebuildPaddingCPUImpl<paddle::bfloat16>(
out.data<paddle::bfloat16>(), out.data<paddle::bfloat16>(),
tmp_out_cpu.data<paddle::bfloat16>(), tmp_out_cpu.data<paddle::bfloat16>(),
cum_offsets_data, cu_seqlens_q_data,
seq_len_this_time_data, seq_len_this_time_data,
seq_lens_decoder_data, seq_lens_decoder_data,
seq_lens_encoder_data, seq_lens_encoder_data,
@@ -230,7 +233,7 @@ std::vector<paddle::Tensor> RebuildPaddingCPU(
std::vector<std::vector<int64_t>> RebuildPaddingInferShape( std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &tmp_out_shape, const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cum_offsets_shape, const std::vector<int64_t> &cu_seqlens_q_shape,
const std::vector<int64_t> &seq_len_this_time_shape, const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape, const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape, const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -239,14 +242,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
if (output_padding_offset_shape) { if (output_padding_offset_shape) {
return {{-1, dim_embed}}; return {{-1, dim_embed}};
} else { } else {
int64_t bsz = cum_offsets_shape[0]; int64_t bsz = cu_seqlens_q_shape[0] - 1;
return {{bsz, dim_embed}}; return {{bsz, dim_embed}};
} }
} }
std::vector<paddle::DataType> RebuildPaddingInferDtype( std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &tmp_out_dtype, const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cum_offsets_dtype, const paddle::DataType &cu_seqlens_q_dtype,
const paddle::DataType &seq_len_this_time_dtype, const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype, const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype, const paddle::DataType &seq_lens_encoder_dtype,
@@ -256,7 +259,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
PD_BUILD_STATIC_OP(rebuild_padding_cpu) PD_BUILD_STATIC_OP(rebuild_padding_cpu)
.Inputs({"tmp_out", .Inputs({"tmp_out",
"cum_offsets", "cu_seqlens_q",
"seq_len_this_time", "seq_len_this_time",
"seq_lens_decoder", "seq_lens_decoder",
"seq_lens_encoder", "seq_lens_encoder",

View File

@@ -101,7 +101,6 @@ std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
cum_offsets_out.data<int>(), cum_offsets_out.data<int>(),
seq_length); seq_length);
return {x_remove_padding, return {x_remove_padding,
cum_offsets_out,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num}; cu_seqlens_k}; // , enc_token_num, dec_token_num};
@@ -114,7 +113,7 @@ std::vector<std::vector<int64_t>> GetPaddingOffsetInferShape(
const std::vector<int64_t> &seq_len_shape) { const std::vector<int64_t> &seq_len_shape) {
int64_t bsz = seq_len_shape[0]; int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1]; int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
} }
std::vector<paddle::DataType> GetPaddingOffsetInferDtype( std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
@@ -123,7 +122,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
const paddle::DataType &token_num_dtype, const paddle::DataType &token_num_dtype,
const paddle::DataType &seq_len_dtype) { const paddle::DataType &seq_len_dtype) {
return {input_ids_dtype, return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype, seq_len_dtype,
seq_len_dtype, seq_len_dtype,
seq_len_dtype}; seq_len_dtype};
@@ -132,7 +130,6 @@ std::vector<paddle::DataType> GetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(get_padding_offset) PD_BUILD_STATIC_OP(get_padding_offset)
.Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"}) .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
.Outputs({"x_remove_padding", .Outputs({"x_remove_padding",
"cum_offsets_out",
"batch_id_per_token", "batch_id_per_token",
"cu_seqlens_q", "cu_seqlens_q",
"cu_seqlens_k"}) "cu_seqlens_k"})

View File

@@ -17,7 +17,7 @@
template <typename T, int VecSize> template <typename T, int VecSize>
__global__ void RebuildPaddingKernel(T *output_data, __global__ void RebuildPaddingKernel(T *output_data,
const T *input_data, const T *input_data,
const int *cum_offsets, const int *cu_seqlens_q,
const int *seq_len_this_time, const int *seq_len_this_time,
const int *seq_len_decoder, const int *seq_len_decoder,
const int *seq_len_encoder, const int *seq_len_encoder,
@@ -34,10 +34,10 @@ __global__ void RebuildPaddingKernel(T *output_data,
int seq_id = 0; int seq_id = 0;
if (seq_len_this_time[bi] == 0) continue; if (seq_len_this_time[bi] == 0) continue;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue; if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
// if encoder, get last token; just decoder, get first token.
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1; if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int ori_token_idx = const int ori_token_idx =
bi * max_input_length - cum_offsets[bi] + seq_id; cu_seqlens_q[bi] + seq_id;
const int src_offset = ori_token_idx * dim_embed + bias_idx; const int src_offset = ori_token_idx * dim_embed + bias_idx;
Load<T, VecSize>(&input_data[src_offset], &src_vec); Load<T, VecSize>(&input_data[src_offset], &src_vec);
Store<T, VecSize>(src_vec, &output_data[i]); Store<T, VecSize>(src_vec, &output_data[i]);
@@ -47,29 +47,31 @@ __global__ void RebuildPaddingKernel(T *output_data,
template <typename T, int VecSize> template <typename T, int VecSize>
__global__ void RebuildAppendPaddingKernel(T *output_data, __global__ void RebuildAppendPaddingKernel(T *output_data,
const T *input_data, const T *input_data,
const int *cum_offset, const int *cu_seqlens_q,
const int *seq_len_this_time, const int *seq_len_this_time,
const int *seq_len_decoder, const int *seq_len_decoder,
const int *seq_len_encoder, const int *seq_len_encoder,
const int *output_padding_offset, const int *output_padding_offset,
const int max_input_length, const int max_input_length,
const int dim_embed, const int dim_embed,
const int64_t output_elem_nums) { const int64_t output_elem_nums,
const int bsz) {
AlignedVector<T, VecSize> src_vec; AlignedVector<T, VecSize> src_vec;
const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; const int64_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
for (int64_t i = global_idx * VecSize; i < output_elem_nums; for (int64_t i = global_idx * VecSize; i < output_elem_nums;
i += gridDim.x * blockDim.x * VecSize) { i += gridDim.x * blockDim.x * VecSize) {
const int out_token_id = i / dim_embed; const int out_token_id = i / dim_embed;
const int ori_token_id = const int ori_token_id = out_token_id + output_padding_offset[out_token_id];
out_token_id + output_padding_offset[out_token_id];
const int bi = ori_token_id / max_input_length; const int bi = ori_token_id / max_input_length;
int seq_id = 0; int seq_id = 0;
if (seq_len_this_time[bi] == 0) continue; if (seq_len_this_time[bi] == 0) continue;
if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue; if (seq_len_decoder[bi] == 0 && seq_len_encoder[bi] == 0) continue;
// if encoder, get last token; just decoder, get first token.
if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int input_token_id = ori_token_id - cum_offset[bi] + seq_id; if (seq_len_encoder[bi] > 0) seq_id = seq_len_encoder[bi] - 1;
const int cum_offset_bi = bi * max_input_length - cu_seqlens_q[bi];
const int input_token_id = ori_token_id - cum_offset_bi + seq_id;
const int bias_idx = i % dim_embed; const int bias_idx = i % dim_embed;
Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx], Load<T, VecSize>(&input_data[input_token_id * dim_embed + bias_idx],
@@ -78,10 +80,11 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,
} }
} }
template <paddle::DataType D> template <paddle::DataType D>
std::vector<paddle::Tensor> rebuild_padding( std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &tmp_out, // [token_num, dim_embed] const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &cu_seqlens_q, // [bsz+1, 1]
const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
@@ -100,7 +103,7 @@ std::vector<paddle::Tensor> rebuild_padding(
std::vector<int64_t> tmp_out_shape = tmp_out.shape(); std::vector<int64_t> tmp_out_shape = tmp_out.shape();
const int token_num = tmp_out_shape[0]; const int token_num = tmp_out_shape[0];
const int dim_embed = tmp_out_shape[1]; const int dim_embed = tmp_out_shape[1];
const int bsz = cum_offsets.shape()[0]; const int bsz = cu_seqlens_q.shape()[0] - 1;
paddle::Tensor out; paddle::Tensor out;
if (output_padding_offset) { if (output_padding_offset) {
@@ -133,21 +136,22 @@ std::vector<paddle::Tensor> rebuild_padding(
<<<grid_size, blocksize, 0, cu_stream>>>( <<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()), reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()), reinterpret_cast<const DataType_ *>(tmp_out.data<data_t>()),
cum_offsets.data<int>(), cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(), seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(), seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
output_padding_offset.get_ptr()->data<int>(), output_padding_offset.get_ptr()->data<int>(),
max_input_length, max_input_length,
dim_embed, dim_embed,
elem_nums); elem_nums,
bsz);
} else { } else {
RebuildPaddingKernel<DataType_, PackSize> RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>( <<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()), reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<DataType_ *>( reinterpret_cast<DataType_ *>(
const_cast<data_t *>(tmp_out.data<data_t>())), const_cast<data_t *>(tmp_out.data<data_t>())),
cum_offsets.data<int>(), cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(), seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(), seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
@@ -160,7 +164,7 @@ std::vector<paddle::Tensor> rebuild_padding(
paddle::Tensor RebuildPaddingFunc( paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &tmp_out, // [token_num, dim_embed] const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &cu_seqlens_q, // [bsz+1, 1]
const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
@@ -170,7 +174,7 @@ paddle::Tensor RebuildPaddingFunc(
case paddle::DataType::BFLOAT16: { case paddle::DataType::BFLOAT16: {
return rebuild_padding<paddle::DataType::BFLOAT16>( return rebuild_padding<paddle::DataType::BFLOAT16>(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -180,7 +184,7 @@ paddle::Tensor RebuildPaddingFunc(
case paddle::DataType::FLOAT16: { case paddle::DataType::FLOAT16: {
return rebuild_padding<paddle::DataType::FLOAT16>( return rebuild_padding<paddle::DataType::FLOAT16>(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -190,7 +194,7 @@ paddle::Tensor RebuildPaddingFunc(
case paddle::DataType::FLOAT32: { case paddle::DataType::FLOAT32: {
return rebuild_padding<paddle::DataType::FLOAT32>( return rebuild_padding<paddle::DataType::FLOAT32>(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -208,14 +212,14 @@ paddle::Tensor RebuildPaddingFunc(
std::vector<paddle::Tensor> RebuildPadding( std::vector<paddle::Tensor> RebuildPadding(
const paddle::Tensor &tmp_out, // [token_num, dim_embed] const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cum_offsets, // [bsz, 1] const paddle::Tensor &cu_seqlens_q, // [bsz+1, 1]
const paddle::Tensor &seq_len_this_time, const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset, const paddle::optional<paddle::Tensor> &output_padding_offset,
int max_input_length) { int max_input_length) {
return {RebuildPaddingFunc(tmp_out, return {RebuildPaddingFunc(tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -225,7 +229,7 @@ std::vector<paddle::Tensor> RebuildPadding(
std::vector<std::vector<int64_t>> RebuildPaddingInferShape( std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
const std::vector<int64_t> &tmp_out_shape, const std::vector<int64_t> &tmp_out_shape,
const std::vector<int64_t> &cum_offsets_shape, const std::vector<int64_t> &cu_seqlens_q_shape,
const std::vector<int64_t> &seq_len_this_time_shape, const std::vector<int64_t> &seq_len_this_time_shape,
const std::vector<int64_t> &seq_lens_decoder_shape, const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_encoder_shape, const std::vector<int64_t> &seq_lens_encoder_shape,
@@ -235,14 +239,14 @@ std::vector<std::vector<int64_t>> RebuildPaddingInferShape(
if (output_padding_offset_shape) { if (output_padding_offset_shape) {
return {{-1, dim_embed}}; return {{-1, dim_embed}};
} else { } else {
int64_t bsz = cum_offsets_shape[0]; int64_t bsz = cu_seqlens_q_shape[0] - 1;
return {{bsz, dim_embed}}; return {{bsz, dim_embed}};
} }
} }
std::vector<paddle::DataType> RebuildPaddingInferDtype( std::vector<paddle::DataType> RebuildPaddingInferDtype(
const paddle::DataType &tmp_out_dtype, const paddle::DataType &tmp_out_dtype,
const paddle::DataType &cum_offsets_dtype, const paddle::DataType &cu_seqlens_q_dtype,
const paddle::DataType &seq_len_this_time_dtype, const paddle::DataType &seq_len_this_time_dtype,
const paddle::DataType &seq_lens_decoder_dtype, const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_encoder_dtype, const paddle::DataType &seq_lens_encoder_dtype,
@@ -252,7 +256,7 @@ std::vector<paddle::DataType> RebuildPaddingInferDtype(
PD_BUILD_STATIC_OP(rebuild_padding) PD_BUILD_STATIC_OP(rebuild_padding)
.Inputs({"tmp_out", .Inputs({"tmp_out",
"cum_offsets", "cu_seqlens_q",
"seq_len_this_time", "seq_len_this_time",
"seq_lens_decoder", "seq_lens_decoder",
"seq_lens_encoder", "seq_lens_encoder",

View File

@@ -106,7 +106,6 @@ std::vector<paddle::Tensor> SpeculateGetPaddingOffset(
seq_length, seq_length,
max_draft_tokens); max_draft_tokens);
return {x_remove_padding, return {x_remove_padding,
cum_offsets_out,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k}; // , enc_token_num, dec_token_num}; cu_seqlens_k}; // , enc_token_num, dec_token_num};
@@ -121,7 +120,7 @@ std::vector<std::vector<int64_t>> SpeculateGetPaddingOffsetInferShape(
const std::vector<int64_t>& seq_lens_encoder_shape) { const std::vector<int64_t>& seq_lens_encoder_shape) {
int64_t bsz = seq_len_shape[0]; int64_t bsz = seq_len_shape[0];
int64_t seq_len = input_ids_shape[1]; int64_t seq_len = input_ids_shape[1];
return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; return {{-1}, {-1}, {bsz + 1}, {bsz + 1}};
} }
std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype( std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
@@ -132,7 +131,6 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
const paddle::DataType& seq_len_dtype, const paddle::DataType& seq_len_dtype,
const paddle::DataType& seq_lens_encoder_dtype) { const paddle::DataType& seq_lens_encoder_dtype) {
return {input_ids_dtype, return {input_ids_dtype,
seq_len_dtype,
seq_len_dtype, seq_len_dtype,
seq_len_dtype, seq_len_dtype,
seq_len_dtype}; seq_len_dtype};
@@ -141,12 +139,10 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(speculate_get_padding_offset) PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids", .Inputs({"input_ids",
"draft_tokens", "draft_tokens",
"cum_offsets",
"token_num", "token_num",
"seq_len", "seq_len",
"seq_lens_encoder"}) "seq_lens_encoder"})
.Outputs({"x_remove_padding", .Outputs({"x_remove_padding",
"cum_offsets_out",
"batch_id_per_token", "batch_id_per_token",
"cu_seqlens_q", "cu_seqlens_q",
"cu_seqlens_k"}) "cu_seqlens_k"})

View File

@@ -112,7 +112,6 @@ def pre_process(
if speculative_decoding: if speculative_decoding:
( (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
@@ -142,14 +141,12 @@ def pre_process(
else: else:
( (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time) ) = get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
return ( return (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
@@ -523,7 +520,7 @@ def step_cuda(
def rebuild_padding( def rebuild_padding(
tmp_out: paddle.Tensor, tmp_out: paddle.Tensor,
cum_offsets: paddle.Tensor, cu_seqlens_q: paddle.Tensor,
seq_len_this_time: paddle.Tensor, seq_len_this_time: paddle.Tensor,
seq_lens_decoder: paddle.Tensor, seq_lens_decoder: paddle.Tensor,
seq_lens_encoder: paddle.Tensor, seq_lens_encoder: paddle.Tensor,
@@ -539,7 +536,7 @@ def rebuild_padding(
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -551,7 +548,7 @@ def rebuild_padding(
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -563,7 +560,7 @@ def rebuild_padding(
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -575,7 +572,7 @@ def rebuild_padding(
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -587,7 +584,7 @@ def rebuild_padding(
hidden_states = rebuild_padding_cpu( hidden_states = rebuild_padding_cpu(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,
@@ -599,7 +596,7 @@ def rebuild_padding(
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
tmp_out, tmp_out,
cum_offsets, cu_seqlens_q,
seq_len_this_time, seq_len_this_time,
seq_lens_decoder, seq_lens_decoder,
seq_lens_encoder, seq_lens_encoder,

View File

@@ -274,7 +274,6 @@ class MTPProposer(Proposer):
self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"]) self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"])
self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"]) self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"])
self.model_inputs["cum_offsets"] = paddle.clone(self.main_model_inputs["cum_offsets"])
self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"]) self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"])
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"]) self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"])
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"]) self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"])
@@ -530,7 +529,6 @@ class MTPProposer(Proposer):
# Remove padding # Remove padding
( (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
@@ -546,7 +544,6 @@ class MTPProposer(Proposer):
) )
# Initialize forward meta data # Initialize forward meta data
self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.model_inputs["cum_offsets"].copy_(cum_offsets, False)
self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -581,7 +578,7 @@ class MTPProposer(Proposer):
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.model_inputs["cum_offsets"], self.model_inputs["cu_seqlens_q"],
self.model_inputs["seq_lens_this_time"], self.model_inputs["seq_lens_this_time"],
self.model_inputs["seq_lens_decoder"], self.model_inputs["seq_lens_decoder"],
self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_encoder"],

View File

@@ -423,7 +423,7 @@ class GCUModelRunner(ModelRunnerBase):
0, 0,
dtype="int64", dtype="int64",
) )
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
@@ -522,7 +522,6 @@ class GCUModelRunner(ModelRunnerBase):
) )
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -742,7 +741,7 @@ class GCUModelRunner(ModelRunnerBase):
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.share_inputs["cum_offsets"], self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_encoder"],
@@ -967,7 +966,7 @@ class GCUModelRunner(ModelRunnerBase):
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.share_inputs["cum_offsets"], self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_encoder"],

View File

@@ -583,7 +583,6 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len self.share_inputs["min_dec_len"][idx : idx + 1] = max_dec_len
self.share_inputs["stop_flags"][idx : idx + 1] = False self.share_inputs["stop_flags"][idx : idx + 1] = False
self.share_inputs["temperature"][idx : idx + 1] = 1 self.share_inputs["temperature"][idx : idx + 1] = 1
self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1] self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = input_length
@@ -680,7 +679,6 @@ class GPUModelRunner(ModelRunnerBase):
0, 0,
dtype="int64", dtype="int64",
) )
self.share_inputs["cum_offsets"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["batch_id_per_token"] = paddle.full( self.share_inputs["batch_id_per_token"] = paddle.full(
[max_num_seqs * self.parallel_config.max_model_len, 1], 0, dtype="int32" [max_num_seqs * self.parallel_config.max_model_len, 1], 0, dtype="int32"
) )
@@ -803,7 +801,6 @@ class GPUModelRunner(ModelRunnerBase):
# Remove padding # Remove padding
( (
ids_remove_padding, ids_remove_padding,
cum_offsets,
batch_id_per_token, batch_id_per_token,
cu_seqlens_q, cu_seqlens_q,
cu_seqlens_k, cu_seqlens_k,
@@ -819,7 +816,6 @@ class GPUModelRunner(ModelRunnerBase):
) )
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
self.share_inputs["cum_offsets"].copy_(cum_offsets, False)
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -965,7 +961,6 @@ class GPUModelRunner(ModelRunnerBase):
cache_kvs_list.append(value_cache) cache_kvs_list.append(value_cache)
self.share_inputs["caches"] = cache_kvs_list self.share_inputs["caches"] = cache_kvs_list
else: else:
for i in range(self.model_config.num_hidden_layers): for i in range(self.model_config.num_hidden_layers):
cache_kvs[f"key_caches_{i}"] = paddle.full( cache_kvs[f"key_caches_{i}"] = paddle.full(
@@ -1071,7 +1066,7 @@ class GPUModelRunner(ModelRunnerBase):
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.share_inputs["cum_offsets"], self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_encoder"],
@@ -1336,7 +1331,7 @@ class GPUModelRunner(ModelRunnerBase):
) )
hidden_states = rebuild_padding( hidden_states = rebuild_padding(
model_output, model_output,
self.share_inputs["cum_offsets"], self.share_inputs["cu_seqlens_q"],
self.share_inputs["seq_lens_this_time"], self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"], self.share_inputs["seq_lens_decoder"],
self.share_inputs["seq_lens_encoder"], self.share_inputs["seq_lens_encoder"],
@@ -1436,6 +1431,7 @@ class GPUModelRunner(ModelRunnerBase):
# 7. Update 'infer_seed' and step_cuda() # 7. Update 'infer_seed' and step_cuda()
self.share_inputs["infer_seed"].add_(self.infer_seed_increment) self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
if not envs.ENABLE_V1_KVCACHE_SCHEDULER: if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
step_cuda( step_cuda(
self.share_inputs, self.share_inputs,