diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 9f0cc3e73..1a3588491 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -378,9 +378,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags, const paddle::Tensor &step_seq_lens_decoder, const paddle::Tensor &block_tables, const paddle::Tensor &is_block_step, - const int block_size); - - + const paddle::optional &draft_tokens, + const paddle::optional &step_draft_tokens, + const paddle::optional &step_seq_lens_this_time, + const int block_size, + const int max_draft_tokens); paddle::Tensor GroupSwigluWithMasked(const paddle::Tensor &fc1_out_tensor, @@ -707,6 +709,22 @@ void SpeculateSaveWithOutputMsgStatic(const paddle::Tensor& accept_tokens, void SpeculateClearAcceptNums(const paddle::Tensor& accept_num, const paddle::Tensor& seq_lens_decoder); +void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, + const paddle::Tensor &block_tables, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &step_draft_tokens, + const paddle::Tensor &step_seq_lens_this_time, + const paddle::Tensor &accept_num, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &stop_nums, + const int block_size, + const int max_draft_tokens); + void NgramMatch(const paddle::Tensor &input_ids, const paddle::Tensor &input_ids_len, const paddle::Tensor &pre_ids, @@ -750,6 +768,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, @@ -763,7 +782,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& base_model_draft_tokens, const int max_draft_token, const bool truncate_first_token, - const bool splitwise_prefill); + const bool splitwise_prefill, + const bool kvcache_scheduler_v1); void DraftModelUpdate(const paddle::Tensor& inter_next_tokens, @@ -1228,6 +1248,8 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("speculate_clear_accept_nums",&SpeculateClearAcceptNums, "speculate_clear_accept_nums function"); + m.def("speculate_schedule_cache",&SpeculateScheduleCache, "SpeculateScheduleCache function"); + m.def("ngram_match", &NgramMatch, "ngram_match function"); m.def("hybird_mtp_ngram", &HybridMtpNgram, "ngram_match_mixed function"); diff --git a/custom_ops/gpu_ops/recover_decode_task.cu b/custom_ops/gpu_ops/recover_decode_task.cu index 88c7dd51c..ae4e77ad6 100644 --- a/custom_ops/gpu_ops/recover_decode_task.cu +++ b/custom_ops/gpu_ops/recover_decode_task.cu @@ -15,31 +15,72 @@ #include "helper.h" __global__ void recover_decode_task(bool *stop_flags, - int *seq_lens_this_time, - int *seq_lens_encoder, - int *seq_lens_decoder, - int *step_seq_lens_decoder, - int *block_tables, - bool *is_block_step, - const int bsz, - const int block_num_per_seq, - const int block_size) { + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + const int bsz, + const int block_num_per_seq, + const int block_size) { int thread_idx = threadIdx.x; if (thread_idx < bsz) { 
if(is_block_step[thread_idx] == true) { int *block_table_now = block_tables + thread_idx * block_num_per_seq; if (block_table_now[step_seq_lens_decoder[thread_idx] / block_size] != -1) { - // can be recovered for decoding - is_block_step[thread_idx] = false; - seq_lens_this_time[thread_idx]= 1; - stop_flags[thread_idx] = false; - seq_lens_encoder[thread_idx] = 0; - seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; - } + // can be recovered for decoding + is_block_step[thread_idx] = false; + seq_lens_this_time[thread_idx]= 1; + stop_flags[thread_idx] = false; + seq_lens_encoder[thread_idx] = 0; + seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; + + } } } } +__global__ void recover_spec_decode_task(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int *block_tables, + bool *is_block_step, + int64_t *draft_tokens, + const int64_t *step_draft_tokens, + const int *step_seq_lens_this_time, + const int bsz, + const int block_num_per_seq, + const int block_size, + const int draft_tokens_len, + const int num_extra_tokens) { + int thread_idx = threadIdx.x; + if (thread_idx < bsz) { + if(is_block_step[thread_idx] == true) { + int *block_table_now = block_tables + thread_idx * block_num_per_seq; + int max_possible_block_idx = (step_seq_lens_decoder[thread_idx] + num_extra_tokens) / block_size; + max_possible_block_idx = min(max_possible_block_idx, block_num_per_seq); + if (block_table_now[max_possible_block_idx] != -1) { + // can be recovered for decoding + int64_t *draft_tokens_now = draft_tokens + thread_idx * draft_tokens_len; + const int64_t *step_draft_tokens_now = step_draft_tokens + thread_idx * draft_tokens_len; + is_block_step[thread_idx] = false; + seq_lens_this_time[thread_idx] = step_seq_lens_this_time[thread_idx]; + stop_flags[thread_idx] = false; + seq_lens_encoder[thread_idx] = 0; + seq_lens_decoder[thread_idx] = step_seq_lens_decoder[thread_idx]; + for (int i = 0; i < seq_lens_this_time[thread_idx]; i++) { + draft_tokens_now[i] = step_draft_tokens_now[i]; + } + + } + } + } +} + + void RecoverDecodeTask(const paddle::Tensor &stop_flags, const paddle::Tensor &seq_lens_this_time, const paddle::Tensor &seq_lens_encoder, @@ -47,7 +88,11 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags, const paddle::Tensor &step_seq_lens_decoder, const paddle::Tensor &block_tables, const paddle::Tensor &is_block_step, - const int block_size) { + const paddle::optional &draft_tokens, + const paddle::optional &step_draft_tokens, + const paddle::optional &step_seq_lens_this_time, + const int block_size, + const int max_draft_tokens) { #ifdef PADDLE_WITH_CUSTOM_DEVICE auto dev_ctx = static_cast(paddle::experimental::DeviceContextPool::Instance().Get(seq_lens_this_time.place())); auto cu_stream = dev_ctx->stream(); @@ -56,17 +101,38 @@ void RecoverDecodeTask(const paddle::Tensor &stop_flags, #endif const int bsz = seq_lens_this_time.shape()[0]; const int block_num_per_seq = block_tables.shape()[1]; - recover_decode_task<<<1, 1024, 0, cu_stream>>>( - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_seq_lens_decoder.data()), - const_cast(block_tables.data()), - const_cast(is_block_step.data()), - bsz, - block_num_per_seq, - block_size); + if (draft_tokens) { + const int draft_tokens_len = draft_tokens.get_ptr()->shape()[1]; + recover_spec_decode_task<<<1, 1024, 0, cu_stream>>>( + 
const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(is_block_step.data()), + const_cast(draft_tokens.get_ptr()->data()), + step_draft_tokens.get_ptr()->data(), + step_seq_lens_this_time.get_ptr()->data(), + bsz, + block_num_per_seq, + block_size, + draft_tokens_len, + max_draft_tokens * 2 + 1); + + } else { + recover_decode_task<<<1, 1024, 0, cu_stream>>>( + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(is_block_step.data()), + bsz, + block_num_per_seq, + block_size); + } } PD_BUILD_STATIC_OP(recover_decode_task) @@ -76,8 +142,11 @@ PD_BUILD_STATIC_OP(recover_decode_task) "seq_lens_decoder", "step_seq_lens_decoder", "block_tables", - "is_block_step"}) - .Attrs({"block_size: int"}) + "is_block_step", + paddle::Optional("draft_tokens"), + paddle::Optional("step_draft_tokens"), + paddle::Optional("step_seq_lens_this_time")}) + .Attrs({"block_size: int", "max_draft_tokens: int"}) .Outputs({"seq_lens_this_time_out", "seq_lens_encoder_out", "seq_lens_decoder_out", diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu index 573f6fb68..051d20a03 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/draft_model_preprocess.cu @@ -15,7 +15,48 @@ #include "helper.h" #include "paddle/extension.h" -template + +#define DISPATCH_BLOCKSIZE(BLOCK_SIZE, ...) \ + do { \ + constexpr int BlockSize = BLOCK_SIZE; \ + __VA_ARGS__; \ + } while (0) + +#define DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, ...) \ + do { \ + if (truncate_first_token) { \ + constexpr bool TRUNCATE_FIRST_TOKEN = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool TRUNCATE_FIRST_TOKEN = false; \ + __VA_ARGS__; \ + } \ + } while (0) + +#define DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, ...) \ + do { \ + if (kvcache_scheduler_v1) { \ + constexpr bool KVCACHE_SCHEDULER_V1 = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool KVCACHE_SCHEDULER_V1 = false; \ + __VA_ARGS__; \ + } \ + } while (0) + +#define DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, ...) 
\ + do { \ + if (splitwise_prefill) { \ + constexpr bool SPLITWISE_PREFILL = true; \ + __VA_ARGS__; \ + } else { \ + constexpr bool SPLITWISE_PREFILL = false; \ + __VA_ARGS__; \ + } \ + } while (0) + + +template __global__ void process_splitwise_prefill( int64_t* draft_tokens, int64_t* input_ids, @@ -25,6 +66,7 @@ __global__ void process_splitwise_prefill( int* seq_lens_decoder, int64_t* step_idx, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, int64_t* pre_ids, const int64_t* accept_tokens, @@ -58,7 +100,7 @@ __global__ void process_splitwise_prefill( stop_flags[tid] = false; int64_t base_model_first_token = accept_tokens_now[0]; int position = seq_len_encoder; - if (TRCUNCATE_FIRST_TOKEN) { + if (TRUNCATE_FIRST_TOKEN) { input_ids_now[position - 1] = base_model_first_token; seq_lens_this_time[tid] = seq_len_encoder; } else { @@ -84,7 +126,7 @@ __global__ void process_splitwise_prefill( -template +template __global__ void draft_model_preprocess_kernel( int64_t* draft_tokens, int64_t* input_ids, @@ -94,6 +136,7 @@ __global__ void draft_model_preprocess_kernel( int* seq_lens_decoder, int64_t* step_idx, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, int64_t* pre_ids, const int64_t* accept_tokens, @@ -134,14 +177,26 @@ __global__ void draft_model_preprocess_kernel( base_model_draft_tokens_now[i] = -1; } - if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { - batch_drop[tid] = true; - stop_flags[tid] = true; + // 1. process block_step situation + // -- In v0 mode, block_step will drop mtp query. + // -- In v1 mode, block_step will continue to infer. + if constexpr(KVCACHE_SCHEDULER_V1) { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + stop_flags[tid] = true; + is_block_step[tid] = true; + // Need to continue infer + } + } else { + if (base_model_stop_flags[tid] && base_model_is_block_step[tid]) { + batch_drop[tid] = true; + stop_flags[tid] = true; + } } + // 2. process normal query, not in any special case. if (!(base_model_stop_flags[tid] || batch_drop[tid])) { not_stop_flag = 1; - // 1. first token + // prefill generation if (seq_lens_encoder[tid] > 0) { // Can be extended to first few tokens int seq_len_encoder = seq_lens_encoder[tid]; @@ -149,14 +204,20 @@ __global__ void draft_model_preprocess_kernel( int64_t base_model_first_token = accept_tokens_now[0]; pre_ids_now[0] = base_model_first_token; int position = seq_len_encoder; - if (TRCUNCATE_FIRST_TOKEN) { + if (TRUNCATE_FIRST_TOKEN) { input_ids_now[position - 1] = base_model_first_token; seq_lens_this_time[tid] = seq_len_encoder; } else { input_ids_now[position] = base_model_first_token; seq_lens_this_time[tid] = seq_len_encoder + 1; } - } else { + } else { // decode generation + if constexpr (KVCACHE_SCHEDULER_V1) { + // 3. 
try to recover mtp infer in V1 mode + if (!base_model_is_block_step[tid] && is_block_step[tid]) { + is_block_step[tid] = false; + } + } if (stop_flags[tid]) { stop_flags[tid] = false; // TODO: check @@ -189,99 +250,8 @@ __global__ void draft_model_preprocess_kernel( } } -template -void DispatchRunner( - const cudaStream_t& stream, - int64_t* draft_tokens, - int64_t* input_ids, - bool* stop_flags, - int* seq_lens_this_time, - int* seq_lens_encoder, - int* seq_lens_decoder, - int64_t* step_idx, - bool* not_need_stop, - bool* batch_drop, - int64_t* pre_ids, - const int64_t* accept_tokens, - const int* accept_num, - const int* base_model_seq_lens_this_time, - const int* base_model_seq_lens_encoder, - const int* base_model_seq_lens_decoder, - const int64_t* base_model_step_idx, - const bool* base_model_stop_flags, - const bool* base_model_is_block_step, - int64_t* base_model_draft_tokens, - const int bsz, - const int num_model_step, - const int accept_tokens_len, - const int draft_tokens_len, - const int input_ids_len, - const int base_model_draft_tokens_len, - const int pre_ids_len, - const bool splitwise_prefill) { - constexpr int BlockSize = 512; - if (splitwise_prefill) { - process_splitwise_prefill - <<<1, BlockSize, 0, stream>>>( - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len); - } else { - draft_model_preprocess_kernel - <<<1, BlockSize, 0, stream>>>( - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len); - } -} -void DispatchTokenMode( +void DispatchRunner( const cudaStream_t &stream, int64_t* draft_tokens, int64_t* input_ids, @@ -291,6 +261,7 @@ void DispatchTokenMode( int* seq_lens_decoder, int64_t* step_idx, bool* not_need_stop, + bool* is_block_step, bool* batch_drop, int64_t* pre_ids, const int64_t* accept_tokens, @@ -310,75 +281,79 @@ void DispatchTokenMode( const int base_model_draft_tokens_len, const int pre_ids_len, const bool truncate_first_token, - const bool splitwise_prefill) { - if (truncate_first_token) { - DispatchRunner( - stream, - draft_tokens, - input_ids, - stop_flags, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - splitwise_prefill - ); - } else { - DispatchRunner( - stream, - draft_tokens, - input_ids, - stop_flags, - 
seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - step_idx, - not_need_stop, - batch_drop, - pre_ids, - accept_tokens, - accept_num, - base_model_seq_lens_this_time, - base_model_seq_lens_encoder, - base_model_seq_lens_decoder, - base_model_step_idx, - base_model_stop_flags, - base_model_is_block_step, - base_model_draft_tokens, - bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - splitwise_prefill - ); - } + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { + DISPATCH_BLOCKSIZE(512, { + DISPATCH_TRUNCATE_FIRST_TOKEN(truncate_first_token, TRUNCATE_FIRST_TOKEN, { + DISPATCH_KVCACHE_SCHEDULER(kvcache_scheduler_v1, KVCACHE_SCHEDULER_V1, { + DISPATCH_SPLITWISE_PREFILL(splitwise_prefill, SPLITWISE_PREFILL, { + if constexpr (SPLITWISE_PREFILL) { + process_splitwise_prefill + <<<1, BlockSize, 0, stream>>>( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len); + } else { + draft_model_preprocess_kernel + <<<1, BlockSize, 0, stream>>>( + draft_tokens, + input_ids, + stop_flags, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + step_idx, + not_need_stop, + is_block_step, + batch_drop, + pre_ids, + accept_tokens, + accept_num, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + base_model_seq_lens_decoder, + base_model_step_idx, + base_model_stop_flags, + base_model_is_block_step, + base_model_draft_tokens, + bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len); + } + }); + }); + }); + }); } - - - void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& input_ids, const paddle::Tensor& stop_flags, @@ -387,6 +362,7 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& seq_lens_decoder, const paddle::Tensor& step_idx, const paddle::Tensor& not_need_stop, + const paddle::Tensor& is_block_step, const paddle::Tensor& batch_drop, const paddle::Tensor& pre_ids, const paddle::Tensor& accept_tokens, @@ -400,7 +376,8 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, const paddle::Tensor& base_model_draft_tokens, const int num_model_step, const bool truncate_first_token, - const bool splitwise_prefill) { + const bool splitwise_prefill, + const bool kvcache_scheduler_v1) { int real_bsz = seq_lens_this_time.shape()[0]; int accept_tokens_len = accept_tokens.shape()[1]; int input_ids_len = input_ids.shape()[1]; @@ -412,36 +389,38 @@ void DraftModelPreprocess(const paddle::Tensor& draft_tokens, auto not_need_stop_gpu = not_need_stop.copy_to(seq_lens_this_time.place(), false); - DispatchTokenMode( - cu_stream, - const_cast(draft_tokens.data()), - const_cast(input_ids.data()), - const_cast(stop_flags.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(step_idx.data()), - const_cast(not_need_stop_gpu.data()), - const_cast(batch_drop.data()), - const_cast(pre_ids.data()), - 
accept_tokens.data(), - accept_num.data(), - base_model_seq_lens_this_time.data(), - base_model_seq_lens_encoder.data(), - base_model_seq_lens_decoder.data(), - base_model_step_idx.data(), - base_model_stop_flags.data(), - base_model_is_block_step.data(), - const_cast(base_model_draft_tokens.data()), - real_bsz, - num_model_step, - accept_tokens_len, - draft_tokens_len, - input_ids_len, - base_model_draft_tokens_len, - pre_ids_len, - truncate_first_token, - splitwise_prefill); + DispatchRunner( + cu_stream, + const_cast(draft_tokens.data()), + const_cast(input_ids.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_idx.data()), + const_cast(not_need_stop_gpu.data()), + const_cast(is_block_step.data()), + const_cast(batch_drop.data()), + const_cast(pre_ids.data()), + accept_tokens.data(), + accept_num.data(), + base_model_seq_lens_this_time.data(), + base_model_seq_lens_encoder.data(), + base_model_seq_lens_decoder.data(), + base_model_step_idx.data(), + base_model_stop_flags.data(), + base_model_is_block_step.data(), + const_cast(base_model_draft_tokens.data()), + real_bsz, + num_model_step, + accept_tokens_len, + draft_tokens_len, + input_ids_len, + base_model_draft_tokens_len, + pre_ids_len, + truncate_first_token, + splitwise_prefill, + kvcache_scheduler_v1); auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); @@ -459,6 +438,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "seq_lens_decoder", "step_idx", "not_need_stop", + "is_block_step", "batch_drop", "pre_ids", "accept_tokens", @@ -480,7 +460,7 @@ PD_BUILD_STATIC_OP(draft_model_preprocess) "not_need_stop_out", "batch_drop_out", "pre_ids_out"}) - .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool"}) + .Attrs({"num_model_step: int", "truncate_first_token: bool", "splitwise_prefill: bool", "kvcache_scheduler_v1: bool"}) .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, {"input_ids", "input_ids_out"}, {"stop_flags", "stop_flags_out"}, diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu new file mode 100644 index 000000000..633c5bb4d --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_schedule_cache.cu @@ -0,0 +1,176 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
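+
+// speculate_schedule_cache: for each still-running query, check whether the
+// KV-cache block that the next speculative step may touch has been allocated.
+// If it has not, park the query: copy its current lengths and draft tokens into
+// the step_* buffers, set is_block_step and stop_flags, and zero its lengths so
+// it is skipped until blocks are available; a CUB block-wide reduction then
+// refreshes not_need_stop. Parked queries are later restored by the
+// recover_decode_task op (recover_spec_decode_task kernel).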
+ +#include "helper.h" + +template +__global__ void speculate_schedula_cache( + const int64_t *draft_tokens, + int *block_tables, + bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *step_seq_lens_decoder, + int64_t *step_draft_tokens, + int *step_seq_lens_this_time, + int *accept_num, + int64_t *accept_tokens, + bool *is_block_step, + bool *not_need_stop, + const int64_t *stop_nums, + const int real_bsz, + const int max_bsz, + const int max_next_step_tokens, + const int draft_tokens_len, + const int accept_tokens_len, + const int block_size, + const int block_num_per_seq) { + const int bid = threadIdx.x; + int stop_flag_now_int = 0; + if (bid < real_bsz) { + if (!stop_flags[bid]) { + const int64_t *draft_tokens_now = draft_tokens + bid * draft_tokens_len; + int64_t *step_draft_tokens_now = step_draft_tokens + bid * draft_tokens_len; + int *block_table_now = block_tables + bid * block_num_per_seq; + int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len; + const int max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) / block_size; + if (max_possible_block_idx < block_num_per_seq && block_table_now[max_possible_block_idx] == -1) { + is_block_step[bid] = true; + step_seq_lens_this_time[bid] = seq_lens_this_time[bid]; + seq_lens_this_time[bid] = 0; + stop_flags[bid] = true; + stop_flag_now_int = 1; + step_seq_lens_decoder[bid] = seq_lens_decoder[bid]; + seq_lens_decoder[bid] = 0; + accept_num[bid] = 0; + for (int i = 0; i < accept_tokens_len; i++) { + accept_tokens_now[i] = -1; + } + for (int i = 0; i < draft_tokens_len; i++) { + step_draft_tokens_now[i] = draft_tokens_now[i]; + } + } + } else { + stop_flag_now_int = 1; + } + } else if (bid >= real_bsz && bid < max_bsz) { + stop_flag_now_int = 1; + } + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // printf("stop_flag_now_int %d \n", stop_flag_now_int); + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); + + if (threadIdx.x == 0) { + // printf("stop_sum %d \n", stop_sum); + not_need_stop[0] = stop_sum < stop_nums[0]; + } +} + +void SpeculateScheduleCache(const paddle::Tensor &draft_tokens, + const paddle::Tensor &block_tables, + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &step_seq_lens_decoder, + const paddle::Tensor &step_draft_tokens, + const paddle::Tensor &step_seq_lens_this_time, + const paddle::Tensor &accept_num, + const paddle::Tensor &accept_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor ¬_need_stop, + const paddle::Tensor &stop_nums, + const int block_size, + const int max_draft_tokens) { + const int real_bsz = seq_lens_this_time.shape()[0]; + const int max_bsz = stop_flags.shape()[0]; + const int accept_tokens_len = accept_tokens.shape()[1]; + const int draft_token_len = draft_tokens.shape()[1]; + const int block_num_per_seq = block_tables.shape()[1]; + + constexpr int BlockSize = 512; + const int max_next_step_tokens = 2 * max_draft_tokens + 2; + + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + speculate_schedula_cache<<<1, BlockSize, 0, seq_lens_this_time.stream()>>>( + draft_tokens.data(), + const_cast(block_tables.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(step_seq_lens_decoder.data()), + const_cast(step_draft_tokens.data()), + 
const_cast(step_seq_lens_this_time.data()), + const_cast(accept_num.data()), + const_cast(accept_tokens.data()), + const_cast(is_block_step.data()), + const_cast(not_need_stop_gpu.data()), + stop_nums.data(), + real_bsz, + max_bsz, + max_next_step_tokens, + draft_token_len, + accept_tokens_len, + block_size, + block_num_per_seq + ); + + auto not_need_stop_cpu = + not_need_stop_gpu.copy_to(not_need_stop.place(), true); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; +} + +PD_BUILD_STATIC_OP(speculate_schedule_cache) + .Inputs({"draft_tokens", + "block_tables", + "stop_flags", + "seq_lens_this_time", + "seq_lens_decoder", + "step_seq_lens_decoder", + "step_draft_tokens", + "step_seq_lens_this_time", + "accept_num", + "accept_tokens", + "is_block_step", + "not_need_stop", + "stop_nums"}) + .Attrs({"block_size: int", "max_draft_tokens: int"}) + .Outputs({"draft_tokens_out", + "block_tables_out", + "stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_decoder_out", + "step_seq_lens_decoder_out", + "step_draft_tokens_out", + "step_seq_lens_this_time_out", + "accept_num_out", + "accept_tokens_out", + "is_block_step_out", + "not_need_stop_out"}) + .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, + {"block_tables", "block_tables_out"}, + {"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"step_seq_lens_decoder", "step_seq_lens_decoder_out"}, + {"step_draft_tokens", "step_draft_tokens_out"}, + {"step_seq_lens_this_time", "step_seq_lens_this_time_out"}, + {"accept_num", "accept_num_out"}, + {"accept_tokens", "accept_tokens_out"}, + {"is_block_step", "is_block_step_out"}, + {"not_need_stop", "not_need_stop_out"},}) + .SetKernelFn(PD_KERNEL(SpeculateScheduleCache)); diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 03a1a8b19..299231274 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -891,7 +891,7 @@ class CacheConfig: else: self.kv_cache_ratio = 0.75 self.enc_dec_block_num = 0 if current_platform.is_iluvatar() else 2 - self.prealloc_dec_block_slot_num_threshold = 5 + self.prealloc_dec_block_slot_num_threshold = 12 self.cache_dtype = "bfloat16" self.model_cfg = None self.enable_chunked_prefill = False diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index b3a8b7e56..664b2b36d 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -162,8 +162,7 @@ class EngineArgs: """ Ratio of tokens to process in a block. """ - - prealloc_dec_block_slot_num_threshold: int = 5 + prealloc_dec_block_slot_num_threshold: int = 12 """ Token slot threshold for preallocating decoder blocks. 
""" @@ -693,7 +692,7 @@ class EngineArgs: cache_group.add_argument( "--prealloc-dec-block-slot-num-threshold", type=int, - default=5, + default=12, help="Number of token slot threadshold to allocate next blocks for decoding.", ) diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index 8db1cca57..5ea7f094a 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -99,10 +99,14 @@ class ResourceManagerV1(ResourceManager): def get_new_block_nums(self, request: Request, num_new_tokens: int): self.check_and_free_block_tables() - return ( + block_num = ( request.num_computed_tokens + num_new_tokens + self.config.cache_config.block_size - 1 ) // self.config.cache_config.block_size - len(request.block_tables) + if self.config.speculative_config.method is not None: + block_num = min(block_num + 1, self.config.cache_config.max_block_num_per_seq) + return block_num + def _prepare_prefill_task(self, request, new_token_num): request.prefill_start_index = request.num_computed_tokens request.prefill_end_index = request.num_computed_tokens + new_token_num diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index b088c7a13..026a3af62 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -331,7 +331,9 @@ def post_process_normal( ) -def post_process_specualate(model_output, save_each_rank: bool = False, skip_save_output: bool = False): +def post_process_specualate( + model_output: ModelOutputData, save_each_rank: bool = False, skip_save_output: bool = False +): """""" speculate_update( model_output.seq_lens_encoder, diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 481dbc13e..1397c79bf 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -19,8 +19,10 @@ from typing import List import numpy as np import paddle +from paddleformers.utils.log import logger -from fastdeploy.engine.request import Request +from fastdeploy import envs +from fastdeploy.engine.request import Request, RequestType from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( @@ -50,14 +52,14 @@ class MTPProposer(Proposer): Proposer for Multi-Token-Prediction(MTP) """ - def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs): + def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs): super().__init__(cfg) self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id self._update_cfg(main_model) self._load_model() - self.main_model_inputs = main_model_inputs + self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps @@ -199,14 +201,16 @@ class MTPProposer(Proposer): encoder_block_shape_q = 64 decoder_block_shape_q = 16 - self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"]) + self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( - 
self.main_model_inputs["decoder_tile_ids_per_batch"] + self.target_model_inputs["decoder_tile_ids_per_batch"] ) self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( - self.main_model_inputs["decoder_num_blocks_cpu"] + self.target_model_inputs["decoder_num_blocks_cpu"] ).pin_memory() - self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu() + self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like( + self.target_model_inputs["max_len_tensor_cpu"] + ).cpu() # Get the attention backend attn_cls = get_attention_backend() @@ -265,24 +269,24 @@ class MTPProposer(Proposer): """ self.model_inputs = {} # Same shape/dytpe with base model - self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"]) - self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"]) - self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"]) + self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"]) + self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"]) + self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"]) - self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"]) - self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"]) - self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"]) - self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"]) - self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"]) + self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"]) + self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"]) + self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"]) + self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"]) + self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"]) self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") - self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"]) - self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"]) - self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"]) - self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"]) - self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"]) - self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"]) + self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"]) + self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"]) + self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"]) + self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"]) + self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"]) + self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"]) self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone( - self.main_model_inputs["decoder_tile_ids_per_batch"] + 
self.target_model_inputs["decoder_tile_ids_per_batch"] ) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) @@ -294,22 +298,22 @@ class MTPProposer(Proposer): ) # self.model_inputs["caches"] = self.cache_kvs # Inherit generation hyperparameters from the main model for consistency - self.model_inputs["top_p"] = self.main_model_inputs["top_p"] - self.model_inputs["top_k"] = self.main_model_inputs["top_k"] - self.model_inputs["temperature"] = self.main_model_inputs["temperature"] - self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"] - self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"] - self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"] - self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"] - self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"] + self.model_inputs["top_p"] = self.target_model_inputs["top_p"] + self.model_inputs["top_k"] = self.target_model_inputs["top_k"] + self.model_inputs["temperature"] = self.target_model_inputs["temperature"] + self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"] + self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"] + self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"] + self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"] + self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"] - self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"] - self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"] + self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"] + self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"] - self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"] + self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"] # Integrate the updated results in model forward - self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"] + self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"] self.model_inputs["substep"] = 0 # Declare AttentionBackend buffers @@ -323,7 +327,7 @@ class MTPProposer(Proposer): shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64" ) - self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"]) + self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"]) self.free_list = list( range( @@ -337,14 +341,77 @@ class MTPProposer(Proposer): self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32") self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32") + self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool") self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool") self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32") if self.num_model_steps > 1: self.last_seq_lens_this_time = paddle.full_like( - self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" + self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32" ) self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + def insert_tasks_v1(self, 
req_dicts: List[Request], num_running_requests: int): + + if "caches" not in self.model_inputs: + self.initialize_kv_cache() + req_len = len(req_dicts) + # has_prefill_task = False + # has_decode_task = False + for i in range(req_len): + request = req_dicts[i] + logger.info(f"{i}th request-{request.request_id}: {request}") + idx = request.idx + if request.task_type.value == RequestType.PREFILL.value: # prefill task + prefill_start_index = request.prefill_start_index + prefill_end_index = request.prefill_end_index + length = prefill_end_index - prefill_start_index + + input_ids = request.prompt_token_ids + request.output_token_ids + + self.input_ids_len[idx] = length + self.model_inputs["pre_ids"][idx : idx + 1] = -1 + self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][ + idx : idx + 1, 1:length + ] + encoder_block_num = len(request.block_tables) + self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + self.model_inputs["stop_flags"][idx : idx + 1] = False + self.model_inputs["batch_drop"][idx : idx + 1] = False + + self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index + self.seq_lens_this_time_buffer[idx : idx + 1] = length + self.model_inputs["step_idx"][idx : idx + 1] = ( + len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 + ) + + # has_prefill_task = True + elif request.task_type.value == RequestType.DECODE.value: # decode task + encoder_block_num = len(request.block_tables) + self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num + self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array( + request.block_tables, dtype="int32" + ) + # if self.model_inputs["is_block_step"][idx]: # has tasks to continue to decode + # has_decode_task = True + # continue + else: + self.model_inputs["block_tables"][idx : idx + 1, :] = -1 + self.model_inputs["stop_flags"][idx : idx + 1] = True + self.seq_lens_this_time_buffer[idx : idx + 1] = 0 + self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0 + self.model_inputs["is_block_step"][idx : idx + 1] = False + continue + # if has_prefill_task or has_decode_task: + # self.model_inputs["not_need_stop"][0] = True + self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): """ Process inputs for prefill tasks and insert it to model_inputs buffer @@ -397,9 +464,9 @@ class MTPProposer(Proposer): length = len(request.prompt_token_ids) if length > 1: - self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][ - idx : idx + 1, 1:length - ] + self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[ + "input_ids" + ][idx : idx + 1, 1:length] self.model_inputs["pre_ids"][idx : idx + 1] = -1 self.model_inputs["step_idx"][idx : idx + 1] = 0 if self.cache_config.enable_chunked_prefill: @@ -455,6 +522,7 @@ class MTPProposer(Proposer): """ Prepare MTP inputs """ + use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER draft_model_preprocess( 
self.model_inputs["draft_tokens"], self.model_inputs["input_ids"], @@ -465,19 +533,21 @@ class MTPProposer(Proposer): self.model_inputs["step_idx"], self.model_inputs["not_need_stop"], self.model_inputs["batch_drop"], + self.model_inputs["is_block_step"], self.model_inputs["pre_ids"], - self.main_model_inputs["accept_tokens"], - self.main_model_inputs["accept_num"], - self.main_model_inputs["seq_lens_this_time"], - self.main_model_inputs["seq_lens_encoder"], - self.main_model_inputs["seq_lens_decoder"], - self.main_model_inputs["step_idx"], - self.main_model_inputs["stop_flags"], - self.main_model_inputs["is_block_step"], - self.main_model_inputs["draft_tokens"], + self.target_model_inputs["accept_tokens"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["seq_lens_decoder"], + self.target_model_inputs["step_idx"], + self.target_model_inputs["stop_flags"], + self.target_model_inputs["is_block_step"], + self.target_model_inputs["draft_tokens"], self.num_model_steps, self.speculative_method in ["eagle", "mtp"], self.role == "prefill", + use_v1_cache_scheduler, ) target_hidden_states = eagle_get_hidden_states( @@ -486,9 +556,9 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_encoder"], self.model_inputs["seq_lens_decoder"], self.model_inputs["stop_flags"], - self.main_model_inputs["accept_num"], - self.main_model_inputs["seq_lens_this_time"], - self.main_model_inputs["seq_lens_encoder"], + self.target_model_inputs["accept_num"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], self.num_model_steps, ) if isinstance(target_hidden_states, list): @@ -658,41 +728,41 @@ class MTPProposer(Proposer): Allocate/Free block of MPT. 
""" draft_model_postprocess( - self.main_model_inputs["draft_tokens"], - self.main_model_inputs["seq_lens_this_time"], - self.main_model_inputs["seq_lens_encoder"], - self.main_model_inputs["stop_flags"], - ) - - mtp_step_paddle( - self.main_model_inputs["stop_flags"], - self.model_inputs["stop_flags"], - self.model_inputs["batch_drop"], - self.model_inputs["seq_lens_this_time"], - self.model_inputs["seq_lens_encoder"], - self.model_inputs["seq_lens_decoder"], - self.model_inputs["block_tables"], - self.model_inputs["encoder_block_lens"], - self.model_inputs["used_list_len"], - self.model_inputs["free_list"], - self.model_inputs["free_list_len"], - self.cache_config.block_size, - self.max_draft_token_num, + self.target_model_inputs["draft_tokens"], + self.target_model_inputs["seq_lens_this_time"], + self.target_model_inputs["seq_lens_encoder"], + self.target_model_inputs["stop_flags"], ) + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: + mtp_step_paddle( + self.target_model_inputs["stop_flags"], + self.model_inputs["stop_flags"], + self.model_inputs["batch_drop"], + self.model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_encoder"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["block_tables"], + self.model_inputs["encoder_block_lens"], + self.model_inputs["used_list_len"], + self.model_inputs["free_list"], + self.model_inputs["free_list_len"], + self.cache_config.block_size, + self.max_draft_token_num, + ) def _extend_draft_token_with_ngram_match(self): # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency device = paddle.CUDAPinnedPlace() - draft_tokens = self.main_model_inputs["draft_tokens"].cpu() - seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu() + draft_tokens = self.target_model_inputs["draft_tokens"].cpu() + seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu() seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() hybrid_mtp_ngram( self.model_inputs["input_ids"]._copy_to(device, True), self.input_ids_len, self.model_inputs["pre_ids"]._copy_to(device, True), self.model_inputs["step_idx"].cpu(), - self.main_model_inputs["actual_draft_token_num"].cpu(), + self.target_model_inputs["actual_draft_token_num"].cpu(), draft_tokens, seq_lens_this_time, seq_lens_decoder, @@ -701,8 +771,8 @@ class MTPProposer(Proposer): self.min_ngram_size, self.max_draft_token_num, ) - self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda() - self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() + self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda() + self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() def _run_impl(self, full_hidden_states): """""" diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index d89990c53..2b85123ac 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -58,6 +58,7 @@ else: recover_decode_task, set_value_by_flags_and_idx, share_external_data, + speculate_schedule_cache, ) from fastdeploy.model_executor.pre_and_post_process import ( @@ -394,6 +395,8 @@ class GPUModelRunner(ModelRunnerBase): if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + if self.speculative_method in ["mtp"]: + self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], 
num_running_requests: int = None): """ @@ -815,6 +818,13 @@ class GPUModelRunner(ModelRunnerBase): fill_value=0, dtype="int32", ) + # For V1_KVCACHE_SCHEDULER + self.share_inputs["step_draft_tokens"] = paddle.full( + shape=[max_num_seqs, max_draft_token_num + 1], + fill_value=0, + dtype="int64", + ) + self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") if self.enable_mm: head_dim = self.model_config.head_dim @@ -853,7 +863,11 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["step_seq_lens_decoder"], self.share_inputs["block_tables"], self.share_inputs["is_block_step"], + self.share_inputs["draft_tokens"] if self.speculative_decoding else None, + self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None, + self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None, self.cache_config.block_size, + self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0, ) # Remove padding @@ -1556,6 +1570,24 @@ class GPUModelRunner(ModelRunnerBase): self._update_chunked_prefill(model_forward_batch) self._add_cache(model_forward_batch) + elif self.speculative_decoding: + speculate_schedule_cache( + self.share_inputs["draft_tokens"], + self.share_inputs["block_tables"], + self.share_inputs["stop_flags"], + self.share_inputs["seq_lens_this_time"], + self.share_inputs["seq_lens_decoder"], + self.share_inputs["step_seq_lens_decoder"], + self.share_inputs["step_draft_tokens"], + self.share_inputs["step_seq_lens_this_time"], + self.share_inputs["accept_num"], + self.share_inputs["accept_tokens"], + self.share_inputs["is_block_step"], + self.share_inputs["not_need_stop"], + self.share_inputs["stop_nums"], + self.cache_config.block_size, + self.speculative_config.num_speculative_tokens, + ) self.seq_lens_this_time_buffer[:num_running_requests].copy_( self.share_inputs["seq_lens_this_time"][:num_running_requests], False diff --git a/tests/operators/test_speculative_schedule_cache.py b/tests/operators/test_speculative_schedule_cache.py new file mode 100644 index 000000000..9c95ad203 --- /dev/null +++ b/tests/operators/test_speculative_schedule_cache.py @@ -0,0 +1,239 @@ +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache + + +def cpu_reference( + draft_tokens, + block_tables, + stop_flags, + seq_lens_this_time, + seq_lens_decoder, + step_seq_lens_decoder, + step_draft_tokens, + step_seq_lens_this_time, + accept_num, + accept_tokens, + is_block_step, + not_need_stop, + stop_nums, + block_size, + max_draft_tokens, +): + """Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads). + + Shapes are the same as inputs to the custom op. This mutates the provided + NumPy arrays in-place, exactly like the kernel does. 
+ """ + real_bsz = seq_lens_this_time.shape[0] + max_bsz = stop_flags.shape[0] + draft_tokens_len = draft_tokens.shape[1] + block_num_per_seq = block_tables.shape[1] + + max_next_step_tokens = 2 * max_draft_tokens + 2 + + # Block-local reduction input per thread (threadIdx.x -> bid) + stop_flag_now_int = np.zeros(512, dtype=np.int64) # THREADBLOCK_SIZE = 512 + + for bid in range(512): + if bid < real_bsz: + if not stop_flags[bid]: + max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size + if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1: + is_block_step[bid] = True + step_seq_lens_this_time[bid] = seq_lens_this_time[bid] + seq_lens_this_time[bid] = 0 + stop_flags[bid] = True + step_seq_lens_decoder[bid] = seq_lens_decoder[bid] + seq_lens_decoder[bid] = 0 + accept_num[bid] = 0 + accept_tokens[bid, :] = -1 + step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len] + stop_flag_now_int[bid] = 1 + else: + stop_flag_now_int[bid] = 0 + else: + stop_flag_now_int[bid] = 1 + elif bid < max_bsz: + # Threads in [real_bsz, max_bsz) contribute 1 to reduction + stop_flag_now_int[bid] = 1 + else: + stop_flag_now_int[bid] = 0 + + stop_sum = int(stop_flag_now_int.sum()) + not_need_stop[0] = stop_sum < int(stop_nums[0]) + + +class TestSpeculateScheduleCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not paddle.is_compiled_with_cuda(): + raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.") + paddle.device.set_device("gpu") + + def setUp(self): + # --- Construct a deterministic case that exercises all branches --- + # real_bsz < max_bsz to test the padding logic in the CUB reduction + self.real_bsz = 3 + self.max_bsz = 5 # only stop_flags has length max_bsz + + self.draft_tokens_len = 6 + self.accept_tokens_len = 5 + self.block_size = 4 + self.block_num_per_seq = 3 + self.max_draft_tokens = 2 # -> max_next_step_tokens = 6 + + # Inputs that will trigger for bid 0, not trigger for bid 2, and bid 1 is already stopped + # seq_lens_decoder + 6 // 4 -> indices: [1, 1, 4]. 
Index 4 is out of range -> no trigger on bid 2 + self.draft_tokens = paddle.to_tensor( + np.array( + [ + [1, 1, 1, 1, 1, 1], + [2, 2, 2, 2, 2, 2], + [3, 3, 3, 3, 3, 3], + ], + dtype=np.int64, + ) + ) + self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32)) + # stop_flags length is max_bsz, others are real_bsz + self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_)) + self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32)) + self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32)) + + # Will be filled by kernel for the triggering bids only + self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32") + self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64") + self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32") + + # Intentionally non-zero so we can verify in-place zeroing only where triggered + self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32)) + self.accept_tokens = paddle.to_tensor( + np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape( + self.real_bsz, self.accept_tokens_len + ) + ) + self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool) + + # not_need_stop lives on CPU in the caller; the kernel copies to device internally + self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu() + + # Choose threshold so with: bid0 triggers, bid1 already stopped, padding (5-3)=2 -> stop_sum = 1+1+2 = 4 + # Set stop_nums to 5 so not_need_stop = (4 < 5) = True + self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64) + + # Keep NumPy copies for CPU reference + self.np_draft_tokens = self.draft_tokens.numpy().copy() + self.np_block_tables = self.block_tables.numpy().copy() + self.np_stop_flags = self.stop_flags.numpy().copy() + self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy() + self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy() + self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy() + self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy() + self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy() + self.np_accept_num = self.accept_num.numpy().copy() + self.np_accept_tokens = self.accept_tokens.numpy().copy() + self.np_is_block_step = self.is_block_step.numpy().copy() + self.np_not_need_stop = self.not_need_stop.numpy().copy() + self.np_stop_nums = self.stop_nums.numpy().copy() + + def test_correctness_against_cpu_reference(self): + # Run GPU kernel (in-place) + speculate_schedule_cache( + self.draft_tokens, + self.block_tables, + self.stop_flags, + self.seq_lens_this_time, + self.seq_lens_decoder, + self.step_seq_lens_decoder, + self.step_draft_tokens, + self.step_seq_lens_this_time, + self.accept_num, + self.accept_tokens, + self.is_block_step, + self.not_need_stop, + self.stop_nums, + self.block_size, + self.max_draft_tokens, + ) + + # Compute CPU reference (in-place on NumPy copies) + cpu_reference( + self.np_draft_tokens, + self.np_block_tables, + self.np_stop_flags, + self.np_seq_lens_this_time, + self.np_seq_lens_decoder, + self.np_step_seq_lens_decoder, + self.np_step_draft_tokens, + self.np_step_seq_lens_this_time, + self.np_accept_num, + self.np_accept_tokens, + self.np_is_block_step, + self.np_not_need_stop, + self.np_stop_nums, + self.block_size, + self.max_draft_tokens, + ) + + # Compare all mutated tensors + 
np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens) + np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens) + np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags) + np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step) + np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time) + np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder) + np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder) + np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time) + np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num) + self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0])) + + def test_no_trigger_path(self): + # Make block_tables at candidate index != -1 so nothing triggers + # Candidate index for bid 0/1 is 1, set it to 7 + bt = self.block_tables.numpy() + bt[:, 1] = 7 + self.block_tables = paddle.to_tensor(bt) + + # Reset outputs to distinctive values + self.step_seq_lens_decoder[:] = 0 + self.step_draft_tokens[:] = 0 + self.step_seq_lens_this_time[:] = 0 + self.accept_num[:] = -123 + self.accept_tokens[:] = -777 + self.is_block_step[:] = False + self.not_need_stop[:] = False + + # For not_need_stop: stopped_in_real = (bid1 True) = 1, padding = 2 -> stop_sum=3 + # With stop_nums=5 -> True + speculate_schedule_cache( + self.draft_tokens, + self.block_tables, + self.stop_flags, + self.seq_lens_this_time, + self.seq_lens_decoder, + self.step_seq_lens_decoder, + self.step_draft_tokens, + self.step_seq_lens_this_time, + self.accept_num, + self.accept_tokens, + self.is_block_step, + self.not_need_stop, + self.stop_nums, + self.block_size, + self.max_draft_tokens, + ) + + # Nothing should have changed except not_need_stop + np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy())) + np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy())) + np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777)) + np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123)) + self.assertTrue(bool(self.not_need_stop.numpy()[0])) + + +if __name__ == "__main__": + unittest.main()
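
The two kernels added in this PR form a park/recover cycle for speculative decoding under the v1 KV-cache scheduler: speculate_schedule_cache parks a query whose next step may run out of blocks (stashing its lengths and draft tokens into the step_* buffers), and recover_decode_task restores that state once blocks are available again. The sketch below is an illustrative, pure-NumPy model of that cycle for a single request; park_if_short_on_blocks, recover_if_blocks_ready, and the state dict are hypothetical names used only here and are not FastDeploy APIs. The headroom constants (2 * max_draft_tokens + 2 when parking, 2 * max_draft_tokens + 1 when recovering) follow the values used in the kernels above.

import numpy as np

def park_if_short_on_blocks(state, block_size, max_draft_tokens):
    """Single-request mirror of speculate_schedule_cache for a still-running query."""
    max_next_step_tokens = 2 * max_draft_tokens + 2
    idx = (state["seq_len_decoder"] + max_next_step_tokens) // block_size
    if idx < len(state["block_table"]) and state["block_table"][idx] == -1:
        # Not enough KV-cache blocks for the next speculative step: park the query.
        state["is_block_step"] = True
        state["stop_flag"] = True
        state["step_seq_len_this_time"] = state["seq_len_this_time"]
        state["step_seq_len_decoder"] = state["seq_len_decoder"]
        state["step_draft_tokens"] = state["draft_tokens"].copy()
        state["seq_len_this_time"] = 0
        state["seq_len_decoder"] = 0
        state["accept_num"] = 0
        state["accept_tokens"][:] = -1

def recover_if_blocks_ready(state, block_size, max_draft_tokens):
    """Single-request mirror of recover_spec_decode_task for a parked query."""
    if not state["is_block_step"]:
        return
    num_extra_tokens = 2 * max_draft_tokens + 1
    idx = (state["step_seq_len_decoder"] + num_extra_tokens) // block_size
    idx = min(idx, len(state["block_table"]) - 1)  # clamped to stay in range in this sketch
    if state["block_table"][idx] != -1:
        # Blocks are available again: restore lengths and draft tokens, resume decoding.
        state["is_block_step"] = False
        state["stop_flag"] = False
        state["seq_len_this_time"] = state["step_seq_len_this_time"]
        state["seq_len_decoder"] = state["step_seq_len_decoder"]
        n = state["seq_len_this_time"]
        state["draft_tokens"][:n] = state["step_draft_tokens"][:n]

if __name__ == "__main__":
    block_size, max_draft_tokens = 4, 2
    state = {
        "block_table": np.array([0, 1, -1, -1], dtype=np.int32),
        "seq_len_this_time": 3,
        "seq_len_decoder": 5,
        "draft_tokens": np.array([11, 12, 13, -1, -1], dtype=np.int64),
        "step_draft_tokens": np.zeros(5, dtype=np.int64),
        "step_seq_len_this_time": 0,
        "step_seq_len_decoder": 0,
        "accept_num": 2,
        "accept_tokens": np.array([11, 12, -1, -1, -1], dtype=np.int64),
        "is_block_step": False,
        "stop_flag": False,
    }
    park_if_short_on_blocks(state, block_size, max_draft_tokens)
    assert state["is_block_step"]            # (5 + 6) // 4 == 2 and block 2 is unallocated
    state["block_table"][2] = 7              # scheduler later allocates the missing block
    recover_if_blocks_ready(state, block_size, max_draft_tokens)
    assert not state["is_block_step"] and state["seq_len_decoder"] == 5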