support mtp in v1_scheduler mode (#3695)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
freeliuzc
2025-09-04 17:39:59 +08:00
committed by GitHub
parent f265a26f8b
commit 88d44a2c93
11 changed files with 909 additions and 316 deletions

View File

@@ -378,9 +378,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const int block_size);
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &step_draft_tokens,
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
const int block_size,
const int max_draft_tokens);
paddle::Tensor
GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor,
@@ -707,6 +709,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens,
void SpeculateClearAcceptNums(const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder);
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens);
void NgramMatch(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
@@ -750,6 +768,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
@@ -763,7 +782,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens,
const int max_draft_token,
const bool truncate_first_token,
const bool splitwise_prefill);
const bool splitwise_prefill,
const bool kvcache_scheduler_v1);
void DraftModelUpdate(const paddle::Tensor& inter_next_tokens,
@@ -1228,6 +1248,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function");
m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function");
m.def("ngram_match", &NgramMatch, "ngram_match function");
m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function");

View File

@@ -35,11 +35,52 @@ __global__ void recover_decode_task(bool *stop_flags,
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
}
}
}
}
__global__ void recover_spec_decode_task(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_encoder,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int *block_tables,
bool *is_block_step,
int64_t *draft_tokens,
const int64_t *step_draft_tokens,
const int *step_seq_lens_this_time,
const int bsz,
const int block_num_per_seq,
const int block_size,
const int draft_tokens_len,
const int num_extra_tokens) {
int thread_idx = threadIdx.x;
if (thread_idx < bsz) {
if(is_block_step[thread_idx] == true) {
int *block_table_now = block_tables + thread_idx * block_num_per_seq;
int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size;
max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq);
if (block_table_now[max_possible_block_idx] != -1) {
// can be recovered for decoding
int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len;
const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len;
is_block_step[thread_idx] = false;
seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx];
stop_flags[thread_idx] = false;
seq_lens_encoder[thread_idx] = 0;
seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx];
for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) {
draft_tokens_now[i] = step_draft_tokens_now[i];
}
}
}
}
}
void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
@@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &block_tables,
const paddle::Tensor &is_block_step,
const int block_size) {
const paddle::optional<paddle::Tensor> &draft_tokens,
const paddle::optional<paddle::Tensor> &step_draft_tokens,
const paddle::optional<paddle::Tensor> &step_seq_lens_this_time,
const int block_size,
const int max_draft_tokens) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place()));
auto cu_stream = dev_ctx->stream();
@@ -56,6 +101,26 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
#endif
const int bsz = seq_lens_this_time.shape()[0];
const int block_num_per_seq = block_tables.shape()[1];
if (draft_tokens) {
const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1];
recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_encoder.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<int64_t *>(draft_tokens.get_ptr()->data<int64_t>()),
step_draft_tokens.get_ptr()->data<int64_t>(),
step_seq_lens_this_time.get_ptr()->data<int>(),
bsz,
block_num_per_seq,
block_size,
draft_tokens_len,
max_draft_tokens * 2 + 1);
} else {
recover_decode_task<<<1, 1024, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
@@ -67,6 +132,7 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags,
bsz,
block_num_per_seq,
block_size);
}
}
PD_BUILD_STATIC_OP(recover_decode_task)
@@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task)
"seq_lens_decoder",
"step_seq_lens_decoder",
"block_tables",
"is_block_step"})
.Attrs({"block_size: int"})
"is_block_step",
paddle::Optional("draft_tokens"),
paddle::Optional("step_draft_tokens"),
paddle::Optional("step_seq_lens_this_time")})
.Attrs({"block_size: int", "max_draft_tokens: int"})
.Outputs({"seq_lens_this_time_out",
"seq_lens_encoder_out",
"seq_lens_decoder_out",

View File

@@ -15,7 +15,48 @@
#include "helper.h"
#include "paddle/extension.h"
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \
do { \
constexpr int BlockSize = BLOCK_SIZE; \
__VA_ARGS__; \
} while (0)
#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \
do { \
if (truncate_first_token) { \
constexpr bool TRUNCATE_FIRST_TOKEN = true; \
__VA_ARGS__; \
} else { \
constexpr bool TRUNCATE_FIRST_TOKEN = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \
do { \
if (kvcache_scheduler_v1) { \
constexpr bool KVCACHE_SCHEDULER_V1 = true; \
__VA_ARGS__; \
} else { \
constexpr bool KVCACHE_SCHEDULER_V1 = false; \
__VA_ARGS__; \
} \
} while (0)
#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) \
do { \
if (splitwise_prefill) { \
constexpr bool SPLITWISE_PREFILL = true; \
__VA_ARGS__; \
} else { \
constexpr bool SPLITWISE_PREFILL = false; \
__VA_ARGS__; \
} \
} while (0)
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void process_splitwise_prefill(
int64_t* draft_tokens,
int64_t* input_ids,
@@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
@@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill(
stop_flags[tid] = false;
int64_t base_model_first_token = accept_tokens_now[0];
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
@@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill(
template <int THREADBLOCK_SIZE, bool TRCUNCATE_FIRST_TOKEN>
template <int THREADBLOCK_SIZE, bool TRUNCATE_FIRST_TOKEN, bool KVCACHE_SCHEDULER_V1>
__global__ void draft_model_preprocess_kernel(
int64_t* draft_tokens,
int64_t* input_ids,
@@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
@@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel(
base_model_draft_tokens_now[i] = -1;
}
// 1. process block_step situation
// -- In v0 mode, block_step will drop mtp query.
// -- In v1 mode, block_step will continue to infer.
if constexpr(KVCACHE_SCHEDULER_V1) {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
stop_flags[tid] = true;
is_block_step[tid] = true;
// Need to continue infer
}
} else {
if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) {
batch_drop[tid] = true;
stop_flags[tid] = true;
}
}
// 2. process normal query, not in any special case.
if (!(base_model_stop_flags[tid] || batch_drop[tid])) {
not_stop_flag = 1;
// 1. first token
// prefill generation
if (seq_lens_encoder[tid] > 0) {
// Can be extended to first few tokens
int seq_len_encoder = seq_lens_encoder[tid];
@@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel(
int64_t base_model_first_token = accept_tokens_now[0];
pre_ids_now[0] = base_model_first_token;
int position = seq_len_encoder;
if (TRCUNCATE_FIRST_TOKEN) {
if (TRUNCATE_FIRST_TOKEN) {
input_ids_now[position - 1] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder;
} else {
input_ids_now[position] = base_model_first_token;
seq_lens_this_time[tid] = seq_len_encoder + 1;
}
} else {
} else { // decode generation
if constexpr (KVCACHE_SCHEDULER_V1) {
// 3. try to recover mtp infer in V1 mode
if (!base_model_is_block_step[tid] && is_block_step[tid]) {
is_block_step[tid] = false;
}
}
if (stop_flags[tid]) {
stop_flags[tid] = false;
// TODO: check
@@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel(
}
}
template <bool TRCUNCATE_FIRST_TOKEN>
void DispatchRunner(
const cudaStream_t& stream,
int64_t* draft_tokens,
int64_t* input_ids,
bool* stop_flags,
int* seq_lens_this_time,
int* seq_lens_encoder,
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
const int* accept_num,
const int* base_model_seq_lens_this_time,
const int* base_model_seq_lens_encoder,
const int* base_model_seq_lens_decoder,
const int64_t* base_model_step_idx,
const bool* base_model_stop_flags,
const bool* base_model_is_block_step,
int64_t* base_model_draft_tokens,
const int bsz,
const int num_model_step,
const int accept_tokens_len,
const int draft_tokens_len,
const int input_ids_len,
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool splitwise_prefill) {
constexpr int BlockSize = 512;
if (splitwise_prefill) {
process_splitwise_prefill<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
} else {
draft_model_preprocess_kernel<BlockSize, TRCUNCATE_FIRST_TOKEN>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
step_idx,
not_need_stop,
batch_drop,
pre_ids,
accept_tokens,
accept_num,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
base_model_seq_lens_decoder,
base_model_step_idx,
base_model_stop_flags,
base_model_is_block_step,
base_model_draft_tokens,
bsz,
num_model_step,
accept_tokens_len,
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len);
}
}
void DispatchTokenMode(
void DispatchRunner(
const cudaStream_t &stream,
int64_t* draft_tokens,
int64_t* input_ids,
@@ -291,6 +261,7 @@ void DispatchTokenMode(
int* seq_lens_decoder,
int64_t* step_idx,
bool* not_need_stop,
bool* is_block_step,
bool* batch_drop,
int64_t* pre_ids,
const int64_t* accept_tokens,
@@ -310,10 +281,15 @@ void DispatchTokenMode(
const int base_model_draft_tokens_len,
const int pre_ids_len,
const bool truncate_first_token,
const bool splitwise_prefill) {
if (truncate_first_token) {
DispatchRunner<true>(
stream,
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
DISPATCH_BLOCKSIZE(512, {
DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, {
DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, {
DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, {
if constexpr (SPLITWISE_PREFILL) {
process_splitwise_prefill<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
@@ -322,6 +298,7 @@ void DispatchTokenMode(
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
@@ -339,12 +316,10 @@ void DispatchTokenMode(
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
splitwise_prefill
);
pre_ids_len);
} else {
DispatchRunner<false>(
stream,
draft_model_preprocess_kernel<BlockSize, TRUNCATE_FIRST_TOKEN, KVCACHE_SCHEDULER_V1>
<<<1, BlockSize, 0, stream>>>(
draft_tokens,
input_ids,
stop_flags,
@@ -353,6 +328,7 @@ void DispatchTokenMode(
seq_lens_decoder,
step_idx,
not_need_stop,
is_block_step,
batch_drop,
pre_ids,
accept_tokens,
@@ -370,15 +346,14 @@ void DispatchTokenMode(
draft_tokens_len,
input_ids_len,
base_model_draft_tokens_len,
pre_ids_len,
splitwise_prefill
);
pre_ids_len);
}
});
});
});
});
}
void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& input_ids,
const paddle::Tensor& stop_flags,
@@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& step_idx,
const paddle::Tensor& not_need_stop,
const paddle::Tensor& is_block_step,
const paddle::Tensor& batch_drop,
const paddle::Tensor& pre_ids,
const paddle::Tensor& accept_tokens,
@@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const paddle::Tensor& base_model_draft_tokens,
const int num_model_step,
const bool truncate_first_token,
const bool splitwise_prefill) {
const bool splitwise_prefill,
const bool kvcache_scheduler_v1) {
int real_bsz = seq_lens_this_time.shape()[0];
int accept_tokens_len = accept_tokens.shape()[1];
int input_ids_len = input_ids.shape()[1];
@@ -412,7 +389,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
auto not_need_stop_gpu =
not_need_stop.copy_to(seq_lens_this_time.place(), false);
DispatchTokenMode(
DispatchRunner(
cu_stream,
const_cast<int64_t*>(draft_tokens.data<int64_t>()),
const_cast<int64_t*>(input_ids.data<int64_t>()),
@@ -422,6 +399,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
const_cast<int*>(seq_lens_decoder.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<bool*>(not_need_stop_gpu.data<bool>()),
const_cast<bool*>(is_block_step.data<bool>()),
const_cast<bool*>(batch_drop.data<bool>()),
const_cast<int64_t*>(pre_ids.data<int64_t>()),
accept_tokens.data<int64_t>(),
@@ -441,7 +419,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens,
base_model_draft_tokens_len,
pre_ids_len,
truncate_first_token,
splitwise_prefill);
splitwise_prefill,
kvcache_scheduler_v1);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), false);
@@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"seq_lens_decoder",
"step_idx",
"not_need_stop",
"is_block_step",
"batch_drop",
"pre_ids",
"accept_tokens",
@@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess)
"not_need_stop_out",
"batch_drop_out",
"pre_ids_out"})
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"})
.Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"input_ids", "input_ids_out"},
{"stop_flags", "stop_flags_out"},

View File

@@ -0,0 +1,176 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "helper.h"
template <int THREADBLOCK_SIZE>
__global__ void speculate_schedula_cache(
const int64_t *draft_tokens,
int *block_tables,
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *step_seq_lens_decoder,
int64_t *step_draft_tokens,
int *step_seq_lens_this_time,
int *accept_num,
int64_t *accept_tokens,
bool *is_block_step,
bool *not_need_stop,
const int64_t *stop_nums,
const int real_bsz,
const int max_bsz,
const int max_next_step_tokens,
const int draft_tokens_len,
const int accept_tokens_len,
const int block_size,
const int block_num_per_seq) {
const int bid = threadIdx.x;
int stop_flag_now_int = 0;
if (bid < real_bsz) {
if (!stop_flags[bid]) {
const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len;
int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len;
int *block_table_now = block_tables + bid * block_num_per_seq;
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size;
if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) {
is_block_step[bid] = true;
step_seq_lens_this_time[bid] = seq_lens_this_time[bid];
seq_lens_this_time[bid] = 0;
stop_flags[bid] = true;
stop_flag_now_int = 1;
step_seq_lens_decoder[bid] = seq_lens_decoder[bid];
seq_lens_decoder[bid] = 0;
accept_num[bid] = 0;
for (int i = 0; i < accept_tokens_len; i++) {
accept_tokens_now[i] = -1;
}
for (int i = 0; i < draft_tokens_len; i++) {
step_draft_tokens_now[i] = draft_tokens_now[i];
}
}
} else {
stop_flag_now_int = 1;
}
} else if (bid >= real_bsz && bid < max_bsz) {
stop_flag_now_int = 1;
}
__syncthreads();
typedef cub::BlockReduce<int64_t, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
// printf("stop_flag_now_int %d \n", stop_flag_now_int);
int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
if (threadIdx.x == 0) {
// printf("stop_sum %d \n", stop_sum);
not_need_stop[0] = stop_sum < stop_nums[0];
}
}
void SpeculateScheduleCache(const paddle::Tensor &draft_tokens,
const paddle::Tensor &block_tables,
const paddle::Tensor &stop_flags,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &step_seq_lens_decoder,
const paddle::Tensor &step_draft_tokens,
const paddle::Tensor &step_seq_lens_this_time,
const paddle::Tensor &accept_num,
const paddle::Tensor &accept_tokens,
const paddle::Tensor &is_block_step,
const paddle::Tensor &not_need_stop,
const paddle::Tensor &stop_nums,
const int block_size,
const int max_draft_tokens) {
const int real_bsz = seq_lens_this_time.shape()[0];
const int max_bsz = stop_flags.shape()[0];
const int accept_tokens_len = accept_tokens.shape()[1];
const int draft_token_len = draft_tokens.shape()[1];
const int block_num_per_seq = block_tables.shape()[1];
constexpr int BlockSize = 512;
const int max_next_step_tokens = 2 * max_draft_tokens + 2;
auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false);
speculate_schedula_cache<BlockSize><<<1, BlockSize, 0, seq_lens_this_time.stream()>>>(
draft_tokens.data<int64_t>(),
const_cast<int *>(block_tables.data<int>()),
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int *>(seq_lens_this_time.data<int>()),
const_cast<int *>(seq_lens_decoder.data<int>()),
const_cast<int *>(step_seq_lens_decoder.data<int>()),
const_cast<int64_t *>(step_draft_tokens.data<int64_t>()),
const_cast<int *>(step_seq_lens_this_time.data<int>()),
const_cast<int *>(accept_num.data<int>()),
const_cast<int64_t *>(accept_tokens.data<int64_t>()),
const_cast<bool *>(is_block_step.data<bool>()),
const_cast<bool *>(not_need_stop_gpu.data<bool>()),
stop_nums.data<int64_t>(),
real_bsz,
max_bsz,
max_next_step_tokens,
draft_token_len,
accept_tokens_len,
block_size,
block_num_per_seq
);
auto not_need_stop_cpu =
not_need_stop_gpu.copy_to(not_need_stop.place(), true);
bool *not_need_stop_data = const_cast<bool *>(not_need_stop.data<bool>());
not_need_stop_data[0] = not_need_stop_cpu.data<bool>()[0];
}
PD_BUILD_STATIC_OP(speculate_schedule_cache)
.Inputs({"draft_tokens",
"block_tables",
"stop_flags",
"seq_lens_this_time",
"seq_lens_decoder",
"step_seq_lens_decoder",
"step_draft_tokens",
"step_seq_lens_this_time",
"accept_num",
"accept_tokens",
"is_block_step",
"not_need_stop",
"stop_nums"})
.Attrs({"block_size: int", "max_draft_tokens: int"})
.Outputs({"draft_tokens_out",
"block_tables_out",
"stop_flags_out",
"seq_lens_this_time_out",
"seq_lens_decoder_out",
"step_seq_lens_decoder_out",
"step_draft_tokens_out",
"step_seq_lens_this_time_out",
"accept_num_out",
"accept_tokens_out",
"is_block_step_out",
"not_need_stop_out"})
.SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
{"block_tables", "block_tables_out"},
{"stop_flags", "stop_flags_out"},
{"seq_lens_this_time", "seq_lens_this_time_out"},
{"seq_lens_decoder", "seq_lens_decoder_out"},
{"step_seq_lens_decoder", "step_seq_lens_decoder_out"},
{"step_draft_tokens", "step_draft_tokens_out"},
{"step_seq_lens_this_time", "step_seq_lens_this_time_out"},
{"accept_num", "accept_num_out"},
{"accept_tokens", "accept_tokens_out"},
{"is_block_step", "is_block_step_out"},
{"not_need_stop", "not_need_stop_out"},})
.SetKernelFn(PD_KERNEL(SpeculateScheduleCache));

View File

@@ -891,7 +891,7 @@ class CacheConfig:
else:
self.kv_cache_ratio = 0.75
self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2
self.prealloc_dec_block_slot_num_threshold = 5
self.prealloc_dec_block_slot_num_threshold = 12
self.cache_dtype = "bfloat16"
self.model_cfg = None
self.enable_chunked_prefill = False

View File

@@ -162,8 +162,7 @@ class EngineArgs:
"""
Ratio of tokens to process in a block.
"""
prealloc_dec_block_slot_num_threshold: int = 5
prealloc_dec_block_slot_num_threshold: int = 12
"""
Token slot threshold for preallocating decoder blocks.
"""
@@ -693,7 +692,7 @@ class EngineArgs:
cache_group.add_argument(
"--prealloc-dec-block-slot-num-threshold",
type=int,
default=5,
default=12,
help="Number of token slot threadshold to allocate next blocks for decoding.",
)

View File

@@ -99,10 +99,14 @@ class ResourceManagerV1(ResourceManager):
def get_new_block_nums(self, request: Request, num_new_tokens: int):
self.check_and_free_block_tables()
return (
block_num = (
request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1
) // self.config.cache_config.block_size - len(request.block_tables)
if self.config.speculative_config.method is not None:
block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq)
return block_num
def _prepare_prefill_task(self, request, new_token_num):
request.prefill_start_index = request.num_computed_tokens
request.prefill_end_index = request.num_computed_tokens + new_token_num

View File

@@ -331,7 +331,9 @@ def post_process_normal(
)
def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False):
def post_process_specualate(
model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False
):
""""""
speculate_update(
model_output.seq_lens_encoder,

View File

@@ -19,8 +19,10 @@ from typing import List
import numpy as np
import paddle
from paddleformers.utils.log import logger
from fastdeploy.engine.request import Request
from fastdeploy import envs
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -50,14 +52,14 @@ class MTPProposer(Proposer):
Proposer for Multi-Token-Prediction(MTP)
"""
def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs):
def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
super().__init__(cfg)
self.num_main_model_layers = self.model_config.num_hidden_layers
self.local_rank = local_rank
self.device_id = device_id
self._update_cfg(main_model)
self._load_model()
self.main_model_inputs = main_model_inputs
self.target_model_inputs = target_model_inputs
self.mtp_strategy = self.speculative_config.mtp_strategy
self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
@@ -199,14 +201,16 @@ class MTPProposer(Proposer):
encoder_block_shape_q = 64
decoder_block_shape_q = 16
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.main_model_inputs["decoder_tile_ids_per_batch"]
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.main_model_inputs["decoder_num_blocks_cpu"]
self.target_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
self.target_model_inputs["max_len_tensor_cpu"]
).cpu()
# Get the attention backend
attn_cls = get_attention_backend()
@@ -265,24 +269,24 @@ class MTPProposer(Proposer):
"""
self.model_inputs = {}
# Same shape/dytpe with base model
self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"])
self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"])
self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"])
self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"])
self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"])
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"])
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"])
self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"])
self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
self.main_model_inputs["decoder_tile_ids_per_batch"]
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -294,22 +298,22 @@ class MTPProposer(Proposer):
)
# self.model_inputs["caches"] = self.cache_kvs
# Inherit generation hyperparameters from the main model for consistency
self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
self.model_inputs["top_k"] = self.main_model_inputs["top_k"]
self.model_inputs["temperature"] = self.main_model_inputs["temperature"]
self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"]
self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"]
self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"]
self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"]
self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"]
self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"]
self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"]
self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"]
self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"]
self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"]
self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"]
self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]
self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"]
self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]
# Integrate the updated results in model forward
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
self.model_inputs["substep"] = 0
# Declare AttentionBackend buffers
@@ -323,7 +327,7 @@ class MTPProposer(Proposer):
shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
)
self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"])
self.free_list = list(
range(
@@ -337,14 +341,77 @@ class MTPProposer(Proposer):
self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")
self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
if self.num_model_steps > 1:
self.last_seq_lens_this_time = paddle.full_like(
self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
)
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
if "caches" not in self.model_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
# has_prefill_task = False
# has_decode_task = False
for i in range(req_len):
request = req_dicts[i]
logger.info(f"{i}th request-{request.request_id}: {request}")
idx = request.idx
if request.task_type.value == RequestType.PREFILL.value: # prefill task
prefill_start_index = request.prefill_start_index
prefill_end_index = request.prefill_end_index
length = prefill_end_index - prefill_start_index
input_ids = request.prompt_token_ids + request.output_token_ids
self.input_ids_len[idx] = length
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
self.model_inputs["stop_flags"][idx : idx + 1] = False
self.model_inputs["batch_drop"][idx : idx + 1] = False
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
self.seq_lens_this_time_buffer[idx : idx + 1] = length
self.model_inputs["step_idx"][idx : idx + 1] = (
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
# has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
encoder_block_num = len(request.block_tables)
self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
request.block_tables, dtype="int32"
)
# if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode
# has_decode_task = True
# continue
else:
self.model_inputs["block_tables"][idx : idx + 1, :] = -1
self.model_inputs["stop_flags"][idx : idx + 1] = True
self.seq_lens_this_time_buffer[idx : idx + 1] = 0
self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.model_inputs["is_block_step"][idx : idx + 1] = False
continue
# if has_prefill_task or has_decode_task:
# self.model_inputs["not_need_stop"][0] = True
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
"""
Process inputs for prefill tasks and insert it to model_inputs buffer
@@ -397,9 +464,9 @@ class MTPProposer(Proposer):
length = len(request.prompt_token_ids)
if length > 1:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["step_idx"][idx : idx + 1] = 0
if self.cache_config.enable_chunked_prefill:
@@ -455,6 +522,7 @@ class MTPProposer(Proposer):
"""
Prepare MTP inputs
"""
use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
draft_model_preprocess(
self.model_inputs["draft_tokens"],
self.model_inputs["input_ids"],
@@ -465,19 +533,21 @@ class MTPProposer(Proposer):
self.model_inputs["step_idx"],
self.model_inputs["not_need_stop"],
self.model_inputs["batch_drop"],
self.model_inputs["is_block_step"],
self.model_inputs["pre_ids"],
self.main_model_inputs["accept_tokens"],
self.main_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"],
self.main_model_inputs["seq_lens_decoder"],
self.main_model_inputs["step_idx"],
self.main_model_inputs["stop_flags"],
self.main_model_inputs["is_block_step"],
self.main_model_inputs["draft_tokens"],
self.target_model_inputs["accept_tokens"],
self.target_model_inputs["accept_num"],
self.target_model_inputs["seq_lens_this_time"],
self.target_model_inputs["seq_lens_encoder"],
self.target_model_inputs["seq_lens_decoder"],
self.target_model_inputs["step_idx"],
self.target_model_inputs["stop_flags"],
self.target_model_inputs["is_block_step"],
self.target_model_inputs["draft_tokens"],
self.num_model_steps,
self.speculative_method in ["eagle", "mtp"],
self.role == "prefill",
use_v1_cache_scheduler,
)
target_hidden_states = eagle_get_hidden_states(
@@ -486,9 +556,9 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_encoder"],
self.model_inputs["seq_lens_decoder"],
self.model_inputs["stop_flags"],
self.main_model_inputs["accept_num"],
self.main_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"],
self.target_model_inputs["accept_num"],
self.target_model_inputs["seq_lens_this_time"],
self.target_model_inputs["seq_lens_encoder"],
self.num_model_steps,
)
if isinstance(target_hidden_states, list):
@@ -658,14 +728,14 @@ class MTPProposer(Proposer):
Allocate/Free block of MPT.
"""
draft_model_postprocess(
self.main_model_inputs["draft_tokens"],
self.main_model_inputs["seq_lens_this_time"],
self.main_model_inputs["seq_lens_encoder"],
self.main_model_inputs["stop_flags"],
self.target_model_inputs["draft_tokens"],
self.target_model_inputs["seq_lens_this_time"],
self.target_model_inputs["seq_lens_encoder"],
self.target_model_inputs["stop_flags"],
)
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
mtp_step_paddle(
self.main_model_inputs["stop_flags"],
self.target_model_inputs["stop_flags"],
self.model_inputs["stop_flags"],
self.model_inputs["batch_drop"],
self.model_inputs["seq_lens_this_time"],
@@ -684,15 +754,15 @@ class MTPProposer(Proposer):
# TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency
device = paddle.CUDAPinnedPlace()
draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
hybrid_mtp_ngram(
self.model_inputs["input_ids"]._copy_to(device, True),
self.input_ids_len,
self.model_inputs["pre_ids"]._copy_to(device, True),
self.model_inputs["step_idx"].cpu(),
self.main_model_inputs["actual_draft_token_num"].cpu(),
self.target_model_inputs["actual_draft_token_num"].cpu(),
draft_tokens,
seq_lens_this_time,
seq_lens_decoder,
@@ -701,8 +771,8 @@ class MTPProposer(Proposer):
self.min_ngram_size,
self.max_draft_token_num,
)
self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
def _run_impl(self, full_hidden_states):
""""""

View File

@@ -58,6 +58,7 @@ else:
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
speculate_schedule_cache,
)
from fastdeploy.model_executor.pre_and_post_process import (
@@ -394,6 +395,8 @@ class GPUModelRunner(ModelRunnerBase):
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
if self.speculative_method in ["mtp"]:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
"""
@@ -815,6 +818,13 @@ class GPUModelRunner(ModelRunnerBase):
fill_value=0,
dtype="int32",
)
# For V1_KVCACHE_SCHEDULER
self.share_inputs["step_draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
)
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
if self.enable_mm:
head_dim = self.model_config.head_dim
@@ -853,7 +863,11 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["block_tables"],
self.share_inputs["is_block_step"],
self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None,
self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None,
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
)
# Remove padding
@@ -1556,6 +1570,24 @@ class GPUModelRunner(ModelRunnerBase):
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
elif self.speculative_decoding:
speculate_schedule_cache(
self.share_inputs["draft_tokens"],
self.share_inputs["block_tables"],
self.share_inputs["stop_flags"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["step_draft_tokens"],
self.share_inputs["step_seq_lens_this_time"],
self.share_inputs["accept_num"],
self.share_inputs["accept_tokens"],
self.share_inputs["is_block_step"],
self.share_inputs["not_need_stop"],
self.share_inputs["stop_nums"],
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens,
)
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False

View File

@@ -0,0 +1,239 @@
import unittest
import numpy as np
import paddle
from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache
def cpu_reference(
draft_tokens,
block_tables,
stop_flags,
seq_lens_this_time,
seq_lens_decoder,
step_seq_lens_decoder,
step_draft_tokens,
step_seq_lens_this_time,
accept_num,
accept_tokens,
is_block_step,
not_need_stop,
stop_nums,
block_size,
max_draft_tokens,
):
"""Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads).
Shapes are the same as inputs to the custom op. This mutates the provided
NumPy arrays in-place, exactly like the kernel does.
"""
real_bsz = seq_lens_this_time.shape[0]
max_bsz = stop_flags.shape[0]
draft_tokens_len = draft_tokens.shape[1]
block_num_per_seq = block_tables.shape[1]
max_next_step_tokens = 2 * max_draft_tokens + 2
# Block-local reduction input per thread (threadIdx.x -> bid)
stop_flag_now_int = np.zeros(512, dtype=np.int64) # THREADBLOCK_SIZE = 512
for bid in range(512):
if bid < real_bsz:
if not stop_flags[bid]:
max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
is_block_step[bid] = True
step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
seq_lens_this_time[bid] = 0
stop_flags[bid] = True
step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
seq_lens_decoder[bid] = 0
accept_num[bid] = 0
accept_tokens[bid, :] = -1
step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
stop_flag_now_int[bid] = 1
else:
stop_flag_now_int[bid] = 0
else:
stop_flag_now_int[bid] = 1
elif bid < max_bsz:
# Threads in [real_bsz, max_bsz) contribute 1 to reduction
stop_flag_now_int[bid] = 1
else:
stop_flag_now_int[bid] = 0
stop_sum = int(stop_flag_now_int.sum())
not_need_stop[0] = stop_sum < int(stop_nums[0])
class TestSpeculateScheduleCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
if not paddle.is_compiled_with_cuda():
raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.")
paddle.device.set_device("gpu")
def setUp(self):
# --- Construct a deterministic case that exercises all branches ---
# real_bsz < max_bsz to test the padding logic in the CUB reduction
self.real_bsz = 3
self.max_bsz = 5 # only stop_flags has length max_bsz
self.draft_tokens_len = 6
self.accept_tokens_len = 5
self.block_size = 4
self.block_num_per_seq = 3
self.max_draft_tokens = 2 # -> max_next_step_tokens = 6
# Inputs that will trigger for bid 0, not trigger for bid 2, and bid 1 is already stopped
# seq_lens_decoder + 6 // 4 -> indices: [1, 1, 4]. Index 4 is out of range -> no trigger on bid 2
self.draft_tokens = paddle.to_tensor(
np.array(
[
[1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3],
],
dtype=np.int64,
)
)
self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
# stop_flags length is max_bsz, others are real_bsz
self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
# Will be filled by kernel for the triggering bids only
self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")
# Intentionally non-zero so we can verify in-place zeroing only where triggered
self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
self.accept_tokens = paddle.to_tensor(
np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
self.real_bsz, self.accept_tokens_len
)
)
self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
# not_need_stop lives on CPU in the caller; the kernel copies to device internally
self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
# Choose threshold so with: bid0 triggers, bid1 already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4
# Set stop_nums to 5 so not_need_stop = (4 < 5) = True
self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)
# Keep NumPy copies for CPU reference
self.np_draft_tokens = self.draft_tokens.numpy().copy()
self.np_block_tables = self.block_tables.numpy().copy()
self.np_stop_flags = self.stop_flags.numpy().copy()
self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
self.np_accept_num = self.accept_num.numpy().copy()
self.np_accept_tokens = self.accept_tokens.numpy().copy()
self.np_is_block_step = self.is_block_step.numpy().copy()
self.np_not_need_stop = self.not_need_stop.numpy().copy()
self.np_stop_nums = self.stop_nums.numpy().copy()
def test_correctness_against_cpu_reference(self):
# Run GPU kernel (in-place)
speculate_schedule_cache(
self.draft_tokens,
self.block_tables,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_decoder,
self.step_seq_lens_decoder,
self.step_draft_tokens,
self.step_seq_lens_this_time,
self.accept_num,
self.accept_tokens,
self.is_block_step,
self.not_need_stop,
self.stop_nums,
self.block_size,
self.max_draft_tokens,
)
# Compute CPU reference (in-place on NumPy copies)
cpu_reference(
self.np_draft_tokens,
self.np_block_tables,
self.np_stop_flags,
self.np_seq_lens_this_time,
self.np_seq_lens_decoder,
self.np_step_seq_lens_decoder,
self.np_step_draft_tokens,
self.np_step_seq_lens_this_time,
self.np_accept_num,
self.np_accept_tokens,
self.np_is_block_step,
self.np_not_need_stop,
self.np_stop_nums,
self.block_size,
self.max_draft_tokens,
)
# Compare all mutated tensors
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))
def test_no_trigger_path(self):
# Make block_tables at candidate index != -1 so nothing triggers
# Candidate index for bid 0/1 is 1, set it to 7
bt = self.block_tables.numpy()
bt[:, 1] = 7
self.block_tables = paddle.to_tensor(bt)
# Reset outputs to distinctive values
self.step_seq_lens_decoder[:] = 0
self.step_draft_tokens[:] = 0
self.step_seq_lens_this_time[:] = 0
self.accept_num[:] = -123
self.accept_tokens[:] = -777
self.is_block_step[:] = False
self.not_need_stop[:] = False
# For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3
# With stop_nums=5 -> True
speculate_schedule_cache(
self.draft_tokens,
self.block_tables,
self.stop_flags,
self.seq_lens_this_time,
self.seq_lens_decoder,
self.step_seq_lens_decoder,
self.step_draft_tokens,
self.step_seq_lens_this_time,
self.accept_num,
self.accept_tokens,
self.is_block_step,
self.not_need_stop,
self.stop_nums,
self.block_size,
self.max_draft_tokens,
)
# Nothing should have changed except not_need_stop
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
self.assertTrue(bool(self.not_need_stop.numpy()[0]))
if __name__ == "__main__":
unittest.main()