support mtp in v1_scheduler mode (#3695)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
freeliuzc
2025-09-04 17:39:59 +08:00
committed by GitHub
parent f265a26f8b
commit 88d44a2c93
11 changed files with 909 additions and 316 deletions

View File

@@ -15,7 +15,48 @@
#include "helper.h"
#include "paddle/extension.h"
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
do { \
constexpr int BlockSize = BLOCK_SIZE; \
__VA_ARGS__; \
} while (0)
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
do { \
if (truncate_first_token) { \
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
__VA_ARGS__; \
} else { \
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
do { \
if (kvcache_scheduler_v1) { \
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
__VA_ARGS__; \
} else { \
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
do { \
if (splitwise_prefill) { \
constexpr bool SPLITWISE_PREFILL = true; \
__VA_ARGS__; \
} else { \
constexpr bool SPLITWISE_PREFILL = false; \
__VA_ARGS__; \
} \
} while (0)
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void process_splitwise_prefill(
int64_t* draft_tokens,
int64_t* input_ids,
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void draft_model_preprocess_kernel(
int64_t* draft_tokens,
int64_t* input_ids,
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
base_model_draft_tokens_now[i] = -1;
}
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
// 1. process block_step situation
// -- In v0 mode, block_step will drop mtp query.
// -- In v1 mode, block_step will continue to infer.
if constexpr(KVCACHE_SCHEDULER_V1) {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
stop_flags[tid] = true;
is_block_step[tid] = true;
// Need to continue infer
}
} else {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
}
// 2. process normal query, not in any special case.
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
// prefill generation
if (seq_lens_encoder[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder = seq_lens_encoder[tid];
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
int64_t base_model_first_token = accept_tokens_now[0];
pre_ids_now[0] = base_model_first_token;
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else {
} else { // decode generation
if constexpr (KVCACHE_SCHEDULER_V1) {
// 3. try to recover mtp infer in V1 mode
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
is_block_step[tid] = false;
}
}
if (stop_flags[tid]) {
stop_flags[tid] = false;
// TODO: check
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
}
}
template <bool TRCUNCATE_FIRST_TOKEN>
void DispatchRunner(
const cudaStream_t& stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool splitwise_prefill) {
constexpr int BlockSize = 512;
if (splitwise_prefill) {
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
}
}
void DispatchTokenMode(
void DispatchRunner(
const cudaStream_t &stream,
int64_t* draft_tokens,
int64_t* input_ids,
@@ -291,6 +261,7 @@ void DispatchTokenMode(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
@@ -310,75 +281,79 @@ void DispatchTokenMode(
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill) {
if (truncate_first_token) {
DispatchRunner<true>(
stream,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
splitwise_prefill
);
} else {
DispatchRunner<false>(
stream,
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
splitwise_prefill
);
}
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
DISPATCH_BLOCKSIZE(512, {
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
if constexpr (SPLITWISE_PREFILL) {
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
}
});
});
});
});
}
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill) {
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
@@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
DispatchTokenMode(
cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
base_model_stop_flags.data<bool>(),
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
DispatchRunner(
cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
const_cast<bool*>(stop_flags.data<bool>()),
const_cast<int*>(seq_lens_this_time.data<int>()),
const_cast<int*>(seq_lens_encoder.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
accept_num.data<int>(),
base_model_seq_lens_this_time.data<int>(),
base_model_seq_lens_encoder.data<int>(),
base_model_seq_lens_decoder.data<int>(),
base_model_step_idx.data<int64_t>(),
base_model_stop_flags.data<bool>(),
base_model_is_block_step.data<bool>(),
const_cast<int64_t*>(base_model_draft_tokens.data<int64_t>()),
real_bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill,
kvcache_scheduler_v1);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder",
"step_idx",
"not_need_stop",
"is_block_step",
"batch_drop",
"pre_ids",
"accept_tokens",
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"not_need_stop_out",
"batch_drop_out",
"pre_ids_out"})
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},

View File

@@ -0,0 +1,176 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void speculate_schedula_cache(
const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
int *step_seq_lens_this_time,
int *accept_num,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
const int draft_tokens_len,
const int accept_tokens_len,
const int block_size,
const int block_num_per_seq) {
const int bid = threadIdx.x;
int stop_flag_now_int = 0;
if (bid < real_bsz) {
if (!stop_flags[bid]) {
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else {
stop_flag_now_int = 1;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
}
}
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
const int accept_tokens_len = accept_tokens.shape()[1];
const int draft_token_len = draft_tokens.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
constexpr int BlockSize = 512;
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
const_cast<int *>(step_seq_lens_this_time.data<int>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_next_step_tokens,
draft_token_len,
accept_tokens_len,
block_size,
block_num_per_seq
);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_schedule_cache)
.Inputs({"draft_tokens",
"block_tables",
"stop_flags",
"seq_lens_this_time",
"seq_lens_decoder",
"step_seq_lens_decoder",
"step_draft_tokens",
"step_seq_lens_this_time",
"accept_num",
"accept_tokens",
"is_block_step",
"not_need_stop",
"stop_nums"})
.Attrs({"block_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out",
"block_tables_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_decoder_out",
"step_seq_lens_decoder_out",
"step_draft_tokens_out",
"step_seq_lens_this_time_out",
"accept_num_out",
"accept_tokens_out",
"is_block_step_out",
"not_need_stop_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
{"accept_num", "accept_num_out"},
{"accept_tokens", "accept_tokens_out"},
{"is_block_step", "is_block_step_out"},
{"not_need_stop", "not_need_stop_out"},})
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));