From 0b7a5778ab47955bf7061b6d3e22696a9ab67aba Mon Sep 17 00:00:00 2001 From: Jundong Liu <61149469+littledgg@users.noreply.github.com> Date: Mon, 13 Oct 2025 15:21:41 +0800 Subject: [PATCH] [Executor]CUDAGraph support Speculate Decode (#4258) * [Executor]CUDAGraph support Speculate Decode * fix problem * solve problem * fix * fast compile * CUDAGraph + mtp support eb5(only target model) * Revert "fast compile" This reverts commit 3cfe8373edbeafd54a13f1d60ea1f81558aa4a94. * fix precommit * solve comment * fix comment about #pragram unroll --------- Co-authored-by: gongshaotian Co-authored-by: gongshaotian --- .../append_attn/append_attention_func.cuh | 5 +- .../speculate_write_cache_with_rope_impl.cuh | 35 +---- custom_ops/gpu_ops/cpp_extensions.cc | 2 +- .../speculate_get_padding_offset.cu | 1 + .../speculate_decoding/speculate_verify.cu | 32 +++- fastdeploy/config.py | 9 +- fastdeploy/model_executor/forward_meta.py | 1 + .../cudagraph_piecewise_backend.py | 26 ++-- .../graph_optimization_backend.py | 6 +- .../model_executor/layers/sample/sampler.py | 2 + .../model_executor/models/ernie4_5_mtp.py | 10 +- fastdeploy/spec_decode/base.py | 36 ++--- fastdeploy/spec_decode/mtp.py | 141 ++++++++++++------ fastdeploy/spec_decode/ngram.py | 4 +- fastdeploy/worker/gpu_model_runner.py | 87 ++++++++--- fastdeploy/worker/gpu_worker.py | 2 +- 16 files changed, 265 insertions(+), 134 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 768e023de..1eaa98bb3 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -2410,6 +2410,9 @@ __global__ void merge_multi_chunks_v2_kernel( __shared__ float md_smem[bdy * 2]; for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; + if(bid == -1){ + continue; + } const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; const int seq_len_q = seq_lens_q[bid]; if (seq_len_q == 0) continue; @@ -2427,7 +2430,7 @@ __global__ void merge_multi_chunks_v2_kernel( seq_len_kv += seq_len_q; } const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); - if (num_chunks_this_seq <= 1) { + if (num_chunks_this_seq <= 1 || !ENABLE_PREFILL) { continue; } diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 4fb5c93d0..cd4439557 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -72,6 +72,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize; const int token_id = linear_index / hidden_size; const int ori_bi = batch_id_per_token[token_id]; + if (ori_bi == -1) continue; if (seq_lens_decoder[ori_bi] == 0) continue; const int bias = linear_index % hidden_size; const int hi = bias / head_size; // q + k + v @@ -84,15 +85,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + return; // 
NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -149,13 +142,13 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( float row_inv_var = Rsqrt(row_variance + rms_norm_eps); if (hi < num_heads) { Load(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec); - #pragma unroll +#pragma unroll for (int i = 0; i < VecSize; i++) { bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * q_norm_vec[i]); } } else { Load(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec); - #pragma unroll +#pragma unroll for (int i = 0; i < VecSize; i++) { bias_vec[i] = static_cast(tmp_vec[i] * row_inv_var * k_norm_vec[i]); } @@ -390,15 +383,7 @@ __global__ void append_speculate_cache_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + return; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -525,15 +510,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + return; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 1ced2ce6f..079bcd543 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -676,7 +676,7 @@ void SpeculateVerify( const paddle::Tensor &output_cum_offsets, const paddle::Tensor &actual_candidate_len, const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode); + int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts); void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu index e37dacbf3..de9b8333d 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu @@ -139,6 +139,7 @@ std::vector SpeculateGetPaddingOffsetInferDtype( PD_BUILD_STATIC_OP(speculate_get_padding_offset) .Inputs({"input_ids", "draft_tokens", + "cum_offsets", "token_num", "seq_len", "seq_lens_encoder"}) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index 0e6e66d00..8ebf4fd48 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -79,7 +79,7 @@ __global__ void speculate_verify( const int *output_cum_offsets, const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, const int max_seq_len, const int max_candidate_len, const 
int verify_window, - const bool prefill_one_step_stop, const bool benchmark_mode) { + const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) { const int bid = threadIdx.x; // verify and set stop flags int accept_num_now = 1; @@ -107,6 +107,24 @@ __global__ void speculate_verify( if (seq_lens_encoder[bid] != 0) { break; } + if (accept_all_drafts) { + // accept all draft tokens + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + continue; + } if (USE_TOPK) { if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]) { @@ -255,7 +273,7 @@ void SpeculateVerify( const paddle::Tensor &output_cum_offsets, const paddle::Tensor &actual_candidate_len, const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) { + int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) { // printf("Enter speculate update\n"); auto bsz = accept_tokens.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0]; @@ -298,7 +316,7 @@ void SpeculateVerify( is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop, benchmark_mode); + prefill_one_step_stop, benchmark_mode, accept_all_drafts); } else { speculate_verify <<<1, BlockSize, 0, accept_tokens.stream()>>>( @@ -314,7 +332,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts); } } else { if (enable_topp) { @@ -332,7 +350,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts); } else { speculate_verify <<<1, BlockSize, 0, accept_tokens.stream()>>>( @@ -348,7 +366,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts); } } @@ -363,7 +381,7 @@ PD_BUILD_STATIC_OP(speculate_verify) "actual_candidate_len", "actual_draft_token_nums", "topp"}) .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out", "stop_flags_out"}) - .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"}) + .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"}) .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, {"accept_num", "accept_num_out"}, {"step_idx", 
"step_idx_out"}, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index d906bbaef..5105cc482 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1150,7 +1150,14 @@ class FDConfig: # Initialize cuda graph capture list if self.graph_opt_config.cudagraph_capture_sizes is None: self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs) - self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs) + + if self.speculative_config is not None and self.speculative_config.method == "mtp": + max_shape = self.parallel_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1) + if max_shape % 2 == 1: + max_shape = max_shape + 1 + self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=min(512, max_shape)) + else: + self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs) # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn if self.graph_opt_config.graph_opt_level == 2: diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 968495733..06ef4b755 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -133,6 +133,7 @@ class ForwardMeta: "shape": obj.shape, "dtype": str(obj.dtype), "place": str(obj.place), + "content": obj if obj.numel() < 70 else "Too big to show", } return tensor_info elif isinstance(obj, (list, tuple)): diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 3465b6092..ce3bedd2e 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import time from contextlib import contextmanager from dataclasses import dataclass, field from typing import Callable, Dict, List, Optional @@ -111,7 +112,7 @@ class CudaGraphPiecewiseBackend: entry.num_finished_warmup += 1 entry.runnable(**kwargs) logger.debug( - f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, " + f"[CUDA GRAPH] [ID:{id(self)}] Warm up for batch size {entry.real_shape}, " f"finished ({n + 1}/{entry.num_finished_warmup}) times" ) @@ -138,15 +139,17 @@ class CudaGraphPiecewiseBackend: real_shape = ids_remove_padding.shape[0] padding_real_shape = self.real_shape_to_captured_size[real_shape] logger.debug( - f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, " - f"The padded shape is :{padding_real_shape}" + f"[CUDA GRAPH] [ID:{id(self)}] The actual real shape obtained by CUDAGraph is :{real_shape}, " + f"The padded shape is :{padding_real_shape}, If Padding :{real_shape != padding_real_shape}" ) entry = self.concrete_size_entries.get(padding_real_shape) assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list." 
if entry.runnable is None: entry.runnable = self.runnable - logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}") + logger.debug( + f"[CUDA GRAPH] [ID:{id(self)}] New entry lazy initialize with real shape {padding_real_shape}" + ) if not entry.use_cudagraph: return entry.runnable(**kwargs) @@ -161,7 +164,7 @@ class CudaGraphPiecewiseBackend: entry.num_finished_warmup += 1 entry.runnable(**kwargs) logger.debug( - f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, " + f"[CUDA GRAPH] [ID:{id(self)}] Warm up for real shape {padding_real_shape}, " f"finished ({n + 1}/{entry.num_finished_warmup}) times" ) @@ -196,11 +199,11 @@ class CudaGraphPiecewiseBackend: # For CUDAGraph debug # self._save_cudagrpah_dot_files(entry) - logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}") + logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}") # Replay entry.cuda_graph.replay() - logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}") + logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph replayed for real shape {padding_real_shape}") if len(entry.output_buffers) == 1: return entry.output_buffers[0] return entry.output_buffers @@ -214,16 +217,17 @@ class CudaGraphPiecewiseBackend: self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape) logger.info( - f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry." + f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph capture list {self.cudagraph_capture_sizes}, " + "Created all real shape entry." ) def clear_graph(self): """ """ # Clear graphs - for id, entry in self.concrete_size_entries.items(): + for _id, entry in self.concrete_size_entries.items(): if entry.cuda_graph: del entry.cuda_graph - logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.") + logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] The CUDAGraph with shape {_id} has been cleared.") del self.concrete_size_entries paddle.device.cuda.empty_cache() @@ -236,6 +240,6 @@ class CudaGraphPiecewiseBackend: log_dir = envs.FD_LOG_DIR if entry.cuda_graph: entry.cuda_graph.print_to_dot_files( - f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}", + f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}_time{time.perf_counter()}", 1 << 0, ) diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py index e843753e8..73fae52e9 100644 --- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py +++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py @@ -115,7 +115,7 @@ class GraphOptBackend: self.runnable = runnable self.fd_config = fd_config - self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0] + self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0] if self.fd_config.graph_opt_config.graph_opt_level > 0: # 1. 
Prepare cuda grpah input buffers (contain output of subgraphs) @@ -138,9 +138,9 @@ class GraphOptBackend: ) assert kwargs["forward_meta"].ids_remove_padding is not None - batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0] + real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0] - if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch): + if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size): return self.runnable(**kwargs) else: return self.cudagraph_piecewise_backend.__call__(**kwargs) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 5aecfa1f9..21c40295a 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -395,6 +395,7 @@ class SpeculativeSampler(nn.Layer): sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + accept_all_drafts: bool = False, ) -> paddle.Tensor: """ """ @@ -451,6 +452,7 @@ class SpeculativeSampler(nn.Layer): self.speculative_verify_window, True, # enable_topp self.speculative_benchmark_mode, + accept_all_drafts, ) return None diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index 19123678a..1af97fe91 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -27,6 +27,9 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.decorator import ( + support_graph_optimization, +) from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer @@ -229,6 +232,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel): return mappings +@support_graph_optimization class Ernie4_5_MTPModel(nn.Layer): """ Ernie4_5_MTPModel @@ -435,6 +439,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): """ forward """ - hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta) + hidden_states = self.ernie( + ids_remove_padding=ids_remove_padding, + previous_hidden_states=previous_hidden_states, + forward_meta=forward_meta, + ) return hidden_states diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 900e99c9b..1b8f98384 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -33,31 +33,33 @@ class Proposer(ABC): the speculative decoding framework """ - def __init__(self, cfg: FDConfig): + def __init__(self, fd_config: FDConfig): """ Init Speculative proposer """ - cfg.parallel_config.tp_group = None - cfg.parallel_config.ep_group = None - self.cfg = deepcopy(cfg) - cfg.parallel_config.tp_group = dist.get_group( - cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + fd_config.parallel_config.tp_group = None + fd_config.parallel_config.ep_group = None + self.fd_config = deepcopy(fd_config) + fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) - cfg.parallel_config.ep_group = dist.get_group( - cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + fd_config.parallel_config.ep_group = dist.get_group( + 
fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET ) - self.cfg.parallel_config.tp_group = dist.get_group( - cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + self.fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) - self.cfg.parallel_config.ep_group = dist.get_group( - cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET + self.fd_config.parallel_config.ep_group = dist.get_group( + fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET ) - self.parallel_config = self.cfg.parallel_config - self.model_config = self.cfg.model_config - self.speculative_config = self.cfg.speculative_config - self.cache_config = self.cfg.cache_config - self.quant_config = self.cfg.quant_config + self.parallel_config = self.fd_config.parallel_config + self.model_config = self.fd_config.model_config + self.speculative_config = self.fd_config.speculative_config + self.cache_config = self.fd_config.cache_config + self.quant_config = self.fd_config.quant_config + self.graph_opt_config = self.fd_config.graph_opt_config + self.scheduler_config = self.fd_config.scheduler_config self.max_num_seqs = self.parallel_config.max_num_seqs self.max_model_len = self.parallel_config.max_model_len diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 92b79b6fb..fb7d32645 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -22,6 +22,7 @@ import paddle from paddleformers.utils.log import logger from fastdeploy import envs +from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request, RequestType from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend @@ -31,6 +32,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import MTPSampler +from fastdeploy.model_executor.model_loader import get_model_loader +from fastdeploy.model_executor.models import ModelForCasualLM from fastdeploy.model_executor.ops.gpu import ( draft_model_postprocess, draft_model_preprocess, @@ -52,12 +55,19 @@ class MTPProposer(Proposer): Proposer for Multi-Token-Prediction(MTP) """ - def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs): - super().__init__(cfg) + def __init__( + self, + fd_config: FDConfig, + main_model: ModelForCasualLM, + local_rank: int, + device_id: int, # physical device id + target_model_inputs, # main model share inputs + ): + super().__init__(fd_config) self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id - self._update_cfg(main_model) + self._update_mtp_config(main_model) self._load_model() self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy @@ -65,16 +75,22 @@ class MTPProposer(Proposer): # [mixed, prefill, decoder] self.role = "mixed" - self.sampler = MTPSampler(cfg) + self.sampler = MTPSampler(fd_config) self._init_model_inputs() + # CUDA Graph + self.use_cudagraph = False # self.graph_opt_config.use_cudagraph + self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) + self.sot_warmup_sizes = 
self.graph_opt_config.sot_warmup_sizes + self.attn_backends: list[AttentionBackend] = [] self._initialize_attn_backend() - def _update_cfg(self, main_model): + def _update_mtp_config(self, main_model): """ Update config for MTP from global config """ + self.forward_meta: ForwardMeta = None self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP") self.speculative_config.sharing_model = main_model self.model_config.num_hidden_layers = 1 @@ -89,21 +105,18 @@ class MTPProposer(Proposer): """ Load MTP Layer """ - from fastdeploy.model_executor.model_loader import get_model_loader - model_loader = get_model_loader(load_config=self.cfg.load_config) - self.model = model_loader.load_model(fd_config=self.cfg) + model_loader = get_model_loader(load_config=self.fd_config.load_config) + self.model = model_loader.load_model(fd_config=self.fd_config) def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): """Set dummy prefill inputs to model_inputs""" max_dec_len = expected_decode_len + 1 - self.num_gpu_blocks = self.parallel_config.total_block_num - self.initialize_kv_cache() - full_length = min( + input_length = min( num_tokens // batch_size, self.parallel_config.max_model_len - max_dec_len, ) - input_length = int(full_length * self.cache_config.kv_cache_ratio) + block_num = ( input_length + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num @@ -125,13 +138,15 @@ class MTPProposer(Proposer): ) self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer - def initialize_kv_cache(self): + def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): """ Initialize kv cache """ - # prompt cache + + self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio) self.cache_kvs = {} + # Get kv cache dtype cache_type = self.parallel_config.dtype kv_cache_quant_type = None @@ -151,9 +166,7 @@ class MTPProposer(Proposer): kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] local_rank = self.local_rank % self.parallel_config.tensor_parallel_size - if not self.parallel_config.do_profile and ( - self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed" - ): + if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"): cache_kvs_list = [] for i in range( self.num_main_model_layers, @@ -230,7 +243,7 @@ class MTPProposer(Proposer): # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( - self.cfg, + self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, @@ -243,7 +256,7 @@ class MTPProposer(Proposer): ) self.attn_backends.append(attn_backend) - def clear_dummy_input(self): + def clear_mtp_cache(self): """ Clear allocated cacheKV """ @@ -251,15 +264,14 @@ class MTPProposer(Proposer): if self.forward_meta is not None: del self.forward_meta.caches - def update_block_num(self, num_gpu_blocks) -> None: + def update_mtp_block_num(self, num_gpu_blocks) -> None: """ - Update block num by theoretical calculation + Update MTP block num by theoretical calculation """ + # Reset block table and kv cache with global block num self.main_model_num_gpu_blocks = num_gpu_blocks - self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) - if not (self.cache_config.enable_prefix_caching or 
self.parallel_config.splitwise_role != "mixed"): - self.initialize_kv_cache() + self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks) # Reset free list free_list = list( @@ -276,7 +288,6 @@ class MTPProposer(Proposer): "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"), } ) - self.parallel_config.do_profile = False def _init_model_inputs(self): """ @@ -300,6 +311,8 @@ class MTPProposer(Proposer): self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"]) self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"]) + self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"]) + self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"]) self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"]) self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"]) self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"]) @@ -308,6 +321,9 @@ class MTPProposer(Proposer): self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone( self.target_model_inputs["decoder_tile_ids_per_batch"] ) + self.model_inputs["target_hidden_states"] = paddle.full( + [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16" + ) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) self.model_inputs["rope_emb"] = get_rope( @@ -443,9 +459,6 @@ class MTPProposer(Proposer): """ Process inputs for prefill tasks and insert it to model_inputs buffer """ - # NOTE: Lazy initialize kv cache - if "caches" not in self.model_inputs: - self.initialize_kv_cache() # TODO:Init role in initialize process if req_dicts[-1].disaggregate_info is not None: @@ -526,7 +539,7 @@ class MTPProposer(Proposer): request.get("block_tables"), dtype="int32" ) self.model_inputs["not_need_stop"][0] = True - self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer def _initialize_forward_meta(self): """ @@ -556,6 +569,33 @@ class MTPProposer(Proposer): for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) + # Update Batch type for cuda graph + only_decode_batch = True + prefill_exists = None + + # Mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + + self.forward_meta.step_use_cudagraph = ( + self.use_cudagraph + and only_decode_batch + and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + ) + + def exist_prefill(self): + """ + check whether prefill stage exist + """ + if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0: + return 1 + else: + return 0 + def _prepare_inputs(self, full_hidden_states): """ Prepare MTP inputs @@ -599,10 +639,8 @@ class MTPProposer(Proposer): self.target_model_inputs["seq_lens_encoder"], 
self.num_model_steps, ) - if isinstance(target_hidden_states, list): - target_hidden_states = target_hidden_states[0] - return target_hidden_states + self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def _post_process(self, sampled_token_ids): """ @@ -633,7 +671,7 @@ class MTPProposer(Proposer): self.parallel_config.use_ep, ) - def _propose(self, target_hidden_states): + def _propose(self): """ Main process for MTP inference """ @@ -663,10 +701,15 @@ class MTPProposer(Proposer): self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) # for speculative decoding - self.model_inputs["output_cum_offsets"] = output_cum_offsets - self.model_inputs["output_padding_offset"] = output_padding_offset + self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) + self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False) + + # Initialize forward meta data self._initialize_forward_meta() + # Padding inputs for cuda graph + self.padding_cudagraph_inputs() + # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], @@ -687,9 +730,11 @@ class MTPProposer(Proposer): model_output = self.model( ids_remove_padding=self.model_inputs["ids_remove_padding"], - previous_hidden_states=target_hidden_states, + previous_hidden_states=self.model_inputs["target_hidden_states"], forward_meta=self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = rebuild_padding( model_output, @@ -721,7 +766,7 @@ class MTPProposer(Proposer): self._post_process(sampled_token_ids) if substep != self.num_model_steps - 1: - target_hidden_states = self._get_self_hidden_states(hidden_states) + self._get_self_hidden_states(hidden_states) else: if hasattr(self.model, "empty_input_forward"): self.model.empty_input_forward() @@ -733,10 +778,7 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_this_time"], self.model_inputs["step_idx"], ) - if isinstance(target_hidden_states, list): - target_hidden_states = target_hidden_states[0] - - return target_hidden_states + self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def update_task_chunk_prefill(self, task): """ @@ -821,8 +863,8 @@ class MTPProposer(Proposer): def _run_impl(self, full_hidden_states): """""" - target_hidden_states = self._prepare_inputs(full_hidden_states) - self._propose(target_hidden_states=target_hidden_states) + self._prepare_inputs(full_hidden_states) + self._propose() self._update_status() if self.hybrid_mode: self._extend_draft_token_with_ngram_match() @@ -830,3 +872,16 @@ class MTPProposer(Proposer): def is_chunk_prefill_enabled(self): """""" return True + + def padding_cudagraph_inputs(self) -> None: + """ + Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. + In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. + """ + # In init_attention_metadata, the decode buffer has already been cleared + + # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. 
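    # Replaying a padded graph means the captured kernels always see the max-batch
    # seq_lens_this_time buffer; real_token_num records how many leading tokens are valid,
    # and _propose slices model_output[: self.real_token_num] so the padded tail never
    # reaches rebuild_padding or sampling.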
+ if self.use_cudagraph: + self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] + return diff --git a/fastdeploy/spec_decode/ngram.py b/fastdeploy/spec_decode/ngram.py index 833a45f54..241a61d22 100644 --- a/fastdeploy/spec_decode/ngram.py +++ b/fastdeploy/spec_decode/ngram.py @@ -29,8 +29,8 @@ class NgramProposer(Proposer): Matching corresponding tokens in input and output as draft tokens. """ - def __init__(self, cfg: FDConfig): - super().__init__(cfg) + def __init__(self, fd_config: FDConfig): + super().__init__(fd_config) self.max_ngram_size = self.speculative_config.max_ngram_size self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index e34f23b16..3c09c6c9d 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -135,7 +135,6 @@ class GPUModelRunner(ModelRunnerBase): # self.kv_caches: list[paddle.Tensor] = [] # Cuda Graph - self.graph_opt_level = self.graph_opt_config.graph_opt_level self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes @@ -155,7 +154,7 @@ class GPUModelRunner(ModelRunnerBase): # In the future, we will expand it as a list. self.attn_backends: list[AttentionBackend] = [] # self.attn_metadatas: list[AttentionMetadata] = [] - self.initialize_attn_backend() + self._initialize_attn_backend() # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None @@ -876,7 +875,6 @@ class GPUModelRunner(ModelRunnerBase): ) self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) - self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -890,6 +888,7 @@ class GPUModelRunner(ModelRunnerBase): # Initialize forward meta data self.initialize_forward_meta() + self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) # Get sampling metadata self.sampling_metadata = SamplingMetadata( @@ -992,7 +991,6 @@ class GPUModelRunner(ModelRunnerBase): Initialize kv cache """ cache_kvs = {} - max_block_num = self.num_gpu_blocks # Get kv cache dtype cache_type = self.parallel_config.dtype @@ -1008,7 +1006,7 @@ class GPUModelRunner(ModelRunnerBase): # Get kv cache shape kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( - max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type + max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type ) if kv_cache_quant_type == "block_wise_fp8": kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]] @@ -1055,7 +1053,7 @@ class GPUModelRunner(ModelRunnerBase): del value paddle.device.cuda.empty_cache() - def initialize_attn_backend(self) -> None: + def _initialize_attn_backend(self) -> None: """ Initialize attention backends """ @@ -1099,6 +1097,7 @@ class GPUModelRunner(ModelRunnerBase): batch_size: paddle.Tensor, expected_decode_len: int = 1, in_capturing: bool = False, + accept_all_drafts: bool = False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. 
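The accept_all_drafts flag added above is forwarded to speculate_verify and is passed as True during part of the draft-model warm-up in capture_model: accepting every draft token keeps the verify result deterministic while capturing, so the subsequent draft-model forward runs at a stable shape. A rough Python rendering of the new kernel branch for a single batch slot, for intuition only (names are simplified; the real CUDA code in speculate_verify.cu works on flattened per-batch offsets and stop flags):

    def accept_all_drafts_for_slot(draft_tokens, accept_tokens, step_idx, max_dec_len, end_tokens):
        # draft_tokens[0] is the last accepted token; the candidates start at index 1.
        accept_num, stop = 1, False
        for i, token in enumerate(draft_tokens[1:]):
            step_idx += 1
            accept_tokens[i] = token
            if token in end_tokens or step_idx >= max_dec_len:
                stop = True
                if step_idx >= max_dec_len:
                    # Hitting the length limit forces an end token, mirroring the kernel.
                    accept_tokens[i] = end_tokens[0]
                break
            accept_num += 1
        return accept_num, step_idx, stop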
@@ -1106,6 +1105,7 @@ class GPUModelRunner(ModelRunnerBase): num_tokens: expected_decode_len: Expected number of tokens generated in_capturing: Is cuda graph in capturing state + accept_all_drafts: Target model will accept all draft tokens """ self._dummy_prefill_inputs( num_tokens=num_tokens, @@ -1134,12 +1134,16 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["image_features"], self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = model_output else: model_output = self.model( ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = rebuild_padding( model_output, @@ -1179,6 +1183,7 @@ class GPUModelRunner(ModelRunnerBase): self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs, + accept_all_drafts, ) sampler_output = None if self.parallel_config.tensor_parallel_size > 1: @@ -1339,14 +1344,55 @@ class GPUModelRunner(ModelRunnerBase): time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - for batch_size in sorted(capture_sizes, reverse=True): - self._dummy_run( - num_tokens=self.parallel_config.max_num_batched_tokens, - batch_size=batch_size, - in_capturing=True, - expected_decode_len=expected_decode_len, - ) - logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") + + if self.speculative_decoding and self.speculative_method == "mtp": + # Capture Target Model without bsz 1 + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture target model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=1, + ) + logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}") + # Capture Draft Model without bsz 1 + # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture Draft model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=True, + ) + logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}") + # Capture Draft Model with bsz 1 + if 1 in capture_sizes: + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=int(1), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=False, + ) + logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}") + else: + for batch_size in sorted(capture_sizes, reverse=True): + self._dummy_run( + num_tokens=self.parallel_config.max_num_batched_tokens, + batch_size=batch_size, + in_capturing=True, + expected_decode_len=expected_decode_len, + ) + logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") @@ -1427,12 +1473,16 @@ class GPUModelRunner(ModelRunnerBase): 
self.share_inputs["image_features"], self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = model_output else: model_output = self.model( ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = rebuild_padding( model_output, self.share_inputs["cu_seqlens_q"], @@ -1628,20 +1678,22 @@ class GPUModelRunner(ModelRunnerBase): # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.parallel_config.total_block_num self.initialize_kv_cache(profile=True) + if self.speculative_method in ["mtp"]: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) # 1. Profile with multimodal encoder & encoder cache # 2. Dummy run self._dummy_run( num_tokens=self.parallel_config.max_num_batched_tokens, - batch_size=min(self.parallel_config.max_num_seqs, 3), + batch_size=self.parallel_config.max_num_seqs, ) # 3. gc self.clear_cache() if self.speculative_method in ["mtp"]: - self.proposer.clear_dummy_input() + self.proposer.clear_mtp_cache() def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ @@ -1671,7 +1723,7 @@ class GPUModelRunner(ModelRunnerBase): ) if self.speculative_method in ["mtp"]: - self.proposer.update_block_num(num_gpu_blocks) + self.proposer.update_mtp_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): """ @@ -1756,6 +1808,7 @@ class GPUModelRunner(ModelRunnerBase): # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. if self.use_cudagraph: self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] return def _init_image_preprocess(self) -> None: diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index e7b1adb4b..fea969090 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -207,7 +207,7 @@ class GpuWorker(WorkerBase): """ Perform the warm-up and the graph optimization """ - if self.model_runner.graph_opt_level >= 1: + if self.fd_config.graph_opt_config.graph_opt_level >= 1: self.model_runner.sot_warmup() # Triger cuda grpah capture self.model_runner.capture_model()