diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 5e4ce35da..7d14ef51f 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -494,12 +494,12 @@ std::vector AppendAttention( paddle::Tensor fmha_out; if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { - fmha_out = GetEmptyTensor( + fmha_out = paddle::zeros( {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, paddle::DataType::INT8, qkv.place()); } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { - fmha_out = GetEmptyTensor( + fmha_out = paddle::zeros( {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, paddle::DataType::FLOAT8_E4M3FN, qkv.place()); @@ -507,7 +507,7 @@ std::vector AppendAttention( PD_THROW("Only supported attr of quant_max_bound in ['127', '448']."); } } else { - fmha_out = GetEmptyTensor( + fmha_out = paddle::zeros( {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, dtype_id, qkv.place()); diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 24787e8b7..cf35e6b5b 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -2418,6 +2418,9 @@ __global__ void merge_multi_chunks_v2_kernel( __shared__ float md_smem[bdy * 2]; for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { const uint32_t bid = batch_id_per_token[qid]; + if(bid == -1){ + continue; + } const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; const int seq_len_q = seq_lens_q[bid]; if (seq_len_q == 0) continue; @@ -2437,6 +2440,8 @@ __global__ void merge_multi_chunks_v2_kernel( const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); if (num_chunks_this_seq <= 1) { continue; + }else if (!ENABLE_PREFILL){ + continue; } using LoadT = AlignedVector; diff --git a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh index 4fb5c93d0..c5c8eca00 100644 --- a/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_impl.cuh @@ -84,15 +84,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + return ; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -390,15 +382,7 @@ __global__ void append_speculate_cache_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + return ; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; @@ -525,15 +509,7 @@ __global__ void append_speculate_cache_neox_rope_kernel( const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq; const int block_idx = block_table_now[write_seq_id / block_size]; if (block_idx < 0) { - printf( - "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var " - "%d %d %d %d\n", - block_idx, - write_seq_id, - ori_bi, - seq_lens_decoder[ori_bi], - token_id, - cu_seqlens_q[ori_bi]); + return ; // NOTE(gongshaotian): For CUDAGraph padding } const int block_offset = write_seq_id % block_size; diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 3379e0cb7..ce4fa1420 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -684,7 +684,7 @@ void SpeculateVerify( const paddle::Tensor &output_cum_offsets, const paddle::Tensor &actual_candidate_len, const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode); + int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts); void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, diff --git a/custom_ops/gpu_ops/rebuild_padding.cu b/custom_ops/gpu_ops/rebuild_padding.cu index 93c1bb38c..d947091ab 100644 --- a/custom_ops/gpu_ops/rebuild_padding.cu +++ b/custom_ops/gpu_ops/rebuild_padding.cu @@ -130,7 +130,6 @@ std::vector rebuild_padding( int pack_num = elem_nums / PackSize; const int blocksize = 128; const int grid_size = (pack_num + blocksize - 1) / blocksize; - if (output_padding_offset) { RebuildAppendPaddingKernel <<>>( diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu index 07f33ee2d..de9b8333d 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_get_padding_offset.cu @@ -139,7 +139,7 @@ std::vector SpeculateGetPaddingOffsetInferDtype( PD_BUILD_STATIC_OP(speculate_get_padding_offset) .Inputs({"input_ids", "draft_tokens", - "cum_offsets" + "cum_offsets", "token_num", "seq_len", "seq_lens_encoder"}) diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu index aa6235687..bab431d71 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu @@ -73,7 +73,7 @@ __global__ void speculate_verify( const int *output_cum_offsets, const int *actual_candidate_len, const int real_bsz, const int max_draft_tokens, const int end_length, const int max_seq_len, const int max_candidate_len, const int verify_window, - const bool prefill_one_step_stop, const bool benchmark_mode) { + const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) { const int bid = threadIdx.x; // verify and set stop flags int accept_num_now = 1; @@ -101,6 +101,24 @@ __global__ void speculate_verify( if (seq_lens_encoder[bid] != 0) { break; } + if (accept_all_drafts) { + // accept all draft tokens + step_idx[bid]++; + auto accept_token = draft_tokens_now[i + 1]; + accept_tokens[bid * max_draft_tokens + i] = accept_token; + + if (is_in_end(accept_token, end_tokens, end_length) || + step_idx[bid] >= max_dec_len[bid]) { + stop_flags[bid] = true; + stop_flag_now_int = 1; + if (step_idx[bid] >= max_dec_len[bid]) + accept_tokens[bid * max_draft_tokens + i] = end_tokens[0]; + break; + } else { + accept_num_now++; + } + continue; + } if (USE_TOPK) { if (verify_tokens_now[i * max_candidate_len] == draft_tokens_now[i + 1]) { @@ -249,7 +267,7 @@ void SpeculateVerify( const paddle::Tensor &output_cum_offsets, const paddle::Tensor &actual_candidate_len, const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp, - int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) { + int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) { // printf("Enter speculate update\n"); auto bsz = accept_tokens.shape()[0]; int real_bsz = seq_lens_this_time.shape()[0]; @@ -292,7 +310,7 @@ void SpeculateVerify( is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, max_candidate_len, verify_window, - prefill_one_step_stop, benchmark_mode); + prefill_one_step_stop, benchmark_mode, accept_all_drafts); } else { speculate_verify <<<1, BlockSize, 0, accept_tokens.stream()>>>( @@ -308,7 +326,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts); } } else { if (enable_topp) { @@ -326,7 +344,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts); } else { speculate_verify <<<1, BlockSize, 0, accept_tokens.stream()>>>( @@ -342,7 +360,7 @@ void SpeculateVerify( end_tokens.data(), is_block_step.data(), output_cum_offsets.data(), actual_candidate_len.data(), real_bsz, max_draft_tokens, end_length, max_seq_len, - max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode); + max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts); } } @@ -357,7 +375,7 @@ PD_BUILD_STATIC_OP(speculate_verify) "actual_candidate_len", "actual_draft_token_nums", "topp"}) .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out", "stop_flags_out"}) - .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"}) + .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"}) .SetInplaceMap({{"accept_tokens", "accept_tokens_out"}, {"accept_num", "accept_num_out"}, {"step_idx", "step_idx_out"}, diff --git a/fastdeploy/config.py b/fastdeploy/config.py index bd61e62b8..4c3530512 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1437,6 +1437,11 @@ class FDConfig: if self.graph_opt_config.cudagraph_only_prefill: self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512) + elif self.speculative_config is not None and self.speculative_config.method == "mtp": + max_shape = self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1) + if max_shape % 2 == 1: + max_shape = max_shape + 1 + self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=min(512, max_shape)) else: self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.scheduler_config.max_num_seqs) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 10608676c..1ea46785e 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -167,7 +167,7 @@ class ForwardMeta: "shape": obj.shape, "dtype": str(obj.dtype), "place": str(obj.place), - # "content": obj if obj.numel()<10 else "Too big to show" + "content": obj if obj.numel() < 70 else "Too big to show", } return tensor_info elif isinstance(obj, (list, tuple)): diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index d31cf7464..6341d3d71 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -121,7 +121,7 @@ class CudaGraphPiecewiseBackend: entry.num_finished_warmup += 1 entry.runnable(**kwargs) logger.debug( - f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, " + f"[CUDA GRAPH][ID:{id(self)}] Warm up for batch size {entry.real_shape}, " f"finished ({n + 1}/{entry.num_finished_warmup}) times" ) @@ -148,15 +148,15 @@ class CudaGraphPiecewiseBackend: real_shape = ids_remove_padding.shape[0] padding_real_shape = self.real_shape_to_captured_size[real_shape] logger.debug( - f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, " - f"The padded shape is :{padding_real_shape}" + f"[CUDA GRAPH][ID:{id(self)}] The actual real shape obtained by CUDAGraph is :{real_shape}, " + f"The padded shape is :{padding_real_shape}, If Padding :{real_shape != padding_real_shape}" ) entry = self.concrete_size_entries.get(padding_real_shape) assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list." if entry.runnable is None: entry.runnable = self.runnable - logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}") + logger.debug(f"[CUDA GRAPH][ID:{id(self)}] New entry lazy initialize with real shape {padding_real_shape}") if not entry.use_cudagraph: return entry.runnable(**kwargs) @@ -171,7 +171,7 @@ class CudaGraphPiecewiseBackend: entry.num_finished_warmup += 1 entry.runnable(**kwargs) logger.debug( - f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, " + f"[CUDA GRAPH][ID:{id(self)}] Warm up for real shape {padding_real_shape}, " f"finished ({n + 1}/{entry.num_finished_warmup}) times" ) @@ -206,11 +206,11 @@ class CudaGraphPiecewiseBackend: # For CUDAGraph debug # self._save_cudagrpah_dot_files(entry) - logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}") + logger.debug(f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}") # Replay entry.cuda_graph.replay() - logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}") + logger.debug(f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph replayed for real shape {padding_real_shape}") if len(entry.output_buffers) == 1: return entry.output_buffers[0] return entry.output_buffers @@ -223,18 +223,19 @@ class CudaGraphPiecewiseBackend: for shape in self.cudagraph_capture_sizes: self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape) - logger.info( - f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry." + logger.debug( + f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph capture list {self.cudagraph_capture_sizes}, " + "Created all real shape entry." ) def clear_graph(self): """ """ # Clear graphs custom_ar_clear_ipc_handles() - for id, entry in self.concrete_size_entries.items(): + for _id, entry in self.concrete_size_entries.items(): if entry.cuda_graph: del entry.cuda_graph - logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.") + logger.debug(f"[CUDA GRAPH][ID:{id(self)}] The CUDAGraph with shape {_id} has been cleared.") del self.concrete_size_entries paddle.device.cuda.empty_cache() diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py index 5ebc82fb1..7c54f52e2 100644 --- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py +++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py @@ -115,7 +115,7 @@ class GraphOptBackend: self.runnable = runnable self.fd_config = fd_config - self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0] + self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0] if self.fd_config.graph_opt_config.graph_opt_level > 0: # 1. Prepare cuda graph input buffers (contain output of subgraphs) @@ -138,9 +138,9 @@ class GraphOptBackend: ) assert kwargs["forward_meta"].ids_remove_padding is not None - batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0] + real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0] - if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch): + if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size): return self.runnable(**kwargs) else: return self.cudagraph_piecewise_backend.__call__(**kwargs) diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 334dcc80f..c97221723 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -461,6 +461,7 @@ class SpeculativeSampler(nn.Layer): sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], + accept_all_drafts: bool = False, ) -> paddle.Tensor: """ """ @@ -517,6 +518,7 @@ class SpeculativeSampler(nn.Layer): self.speculative_verify_window, True, # enable_topp self.speculative_benchmark_mode, + accept_all_drafts, ) return None diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index c9bef3700..fa8bb8e10 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -28,6 +28,9 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.decorator import ( + support_graph_optimization, +) from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer @@ -234,6 +237,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel): return mappings +@support_graph_optimization class Ernie4_5_MTPModel(nn.Layer): """ Ernie4_5_MTPModel @@ -457,6 +461,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM): """ forward """ - hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta) + hidden_states = self.ernie( + ids_remove_padding=ids_remove_padding, + previous_hidden_states=previous_hidden_states, + forward_meta=forward_meta, + ) return hidden_states diff --git a/fastdeploy/spec_decode/base.py b/fastdeploy/spec_decode/base.py index 5c395b0d5..a7d8f2266 100644 --- a/fastdeploy/spec_decode/base.py +++ b/fastdeploy/spec_decode/base.py @@ -33,24 +33,25 @@ class Proposer(ABC): the speculative decoding framework """ - def __init__(self, cfg: FDConfig): + def __init__(self, fd_config: FDConfig): """ Init Speculative proposer """ - cfg.parallel_config.tp_group = None - self.cfg = deepcopy(cfg) - cfg.parallel_config.tp_group = dist.get_group( - cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + fd_config.parallel_config.tp_group = None + self.fd_config = deepcopy(fd_config) + fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) - self.cfg.parallel_config.tp_group = dist.get_group( - cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET + self.fd_config.parallel_config.tp_group = dist.get_group( + fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET ) - self.parallel_config = self.cfg.parallel_config - self.model_config = self.cfg.model_config - self.speculative_config = self.cfg.speculative_config - self.cache_config = self.cfg.cache_config - self.quant_config = self.cfg.quant_config - self.scheduler_config = self.cfg.scheduler_config + self.parallel_config = self.fd_config.parallel_config + self.model_config = self.fd_config.model_config + self.speculative_config = self.fd_config.speculative_config + self.cache_config = self.fd_config.cache_config + self.quant_config = self.fd_config.quant_config + self.graph_opt_config = self.fd_config.graph_opt_config + self.scheduler_config = self.fd_config.scheduler_config self.max_num_seqs = self.scheduler_config.max_num_seqs self.max_model_len = self.parallel_config.max_model_len diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 349b8ce4d..945962d55 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -19,7 +19,6 @@ from typing import List import numpy as np import paddle -from paddle import nn from paddleformers.utils.log import logger from fastdeploy import envs @@ -33,6 +32,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import MTPSampler +from fastdeploy.model_executor.model_loader import get_model_loader +from fastdeploy.model_executor.models import ModelForCasualLM from fastdeploy.model_executor.ops.gpu import ( draft_model_postprocess, draft_model_preprocess, @@ -54,12 +55,19 @@ class MTPProposer(Proposer): Proposer for Multi-Token-Prediction(MTP) """ - def __init__(self, cfg: FDConfig, main_model: nn.Layer, local_rank: int, device_id: int, target_model_inputs): - super().__init__(cfg) + def __init__( + self, + fd_config: FDConfig, + main_model: ModelForCasualLM, + local_rank: int, + device_id: int, # physical device id + target_model_inputs, # main model share inputs + ): + super().__init__(fd_config) self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id - self._update_cfg(main_model) + self._update_mtp_config(main_model) self._load_model() self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy @@ -67,13 +75,22 @@ class MTPProposer(Proposer): # [mixed, prefill, decoder] self.role = "mixed" - self.sampler = MTPSampler(cfg) + + self.sampler = MTPSampler(fd_config) self._init_model_inputs() + # CUDA Graph + self.use_cudagraph = self.graph_opt_config.use_cudagraph + self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) + self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes + self.attn_backends: list[AttentionBackend] = [] self._initialize_attn_backend() - def _update_cfg(self, main_model): + # Forward meta store the global meta information of the forward + self.forward_meta: ForwardMeta = None + + def _update_mtp_config(self, main_model): """ Update config for MTP from global config """ @@ -91,21 +108,17 @@ class MTPProposer(Proposer): """ Load MTP Layer """ - from fastdeploy.model_executor.model_loader import get_model_loader - - model_loader = get_model_loader(load_config=self.cfg.load_config) - self.model = model_loader.load_model(fd_config=self.cfg) + model_loader = get_model_loader(load_config=self.fd_config.load_config) + self.model = model_loader.load_model(fd_config=self.fd_config) def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): """Set dummy prefill inputs to model_inputs""" max_dec_len = expected_decode_len + 1 - self.num_gpu_blocks = self.parallel_config.total_block_num - self.initialize_kv_cache() - full_length = min( + + input_length = min( num_tokens // batch_size, self.parallel_config.max_model_len - max_dec_len, ) - input_length = int(full_length * self.cache_config.kv_cache_ratio) block_num = ( input_length + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num @@ -127,15 +140,15 @@ class MTPProposer(Proposer): ) self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer - def initialize_kv_cache(self): + def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): """ Initialize kv cache """ - # prompt cache + self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio) self.cache_kvs = {} + # Get kv cache dtype cache_type = self.parallel_config.dtype - kv_cache_quant_type = None if ( self.quant_config @@ -149,7 +162,7 @@ class MTPProposer(Proposer): kv_cache_shape = self.attn_backends[0].get_kv_cache_shape( max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type ) - if not self.parallel_config.do_profile and ( + if not profile and ( self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed" ): cache_kvs_list = [] @@ -239,7 +252,7 @@ class MTPProposer(Proposer): # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( - self.cfg, + self.fd_config, kv_num_heads=self.model_config.kv_num_heads, num_heads=num_heads, head_dim=head_dim, @@ -252,7 +265,7 @@ class MTPProposer(Proposer): ) self.attn_backends.append(attn_backend) - def clear_dummy_input(self): + def clear_mtp_cache(self): """ Clear allocated cacheKV """ @@ -260,15 +273,13 @@ class MTPProposer(Proposer): if self.forward_meta is not None: del self.forward_meta.caches - def update_block_num(self, num_gpu_blocks) -> None: + def update_mtp_block_num(self, num_gpu_blocks) -> None: """ - Update block num by theoretical calculation + Update MTP block num by theoretical calculation """ - + # Reset block table and kv cache with global block num self.main_model_num_gpu_blocks = num_gpu_blocks - self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio) - if not (self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"): - self.initialize_kv_cache() + self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks) # Reset free list free_list = list( @@ -285,7 +296,6 @@ class MTPProposer(Proposer): "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"), } ) - self.parallel_config.do_profile = False def _init_model_inputs(self): """ @@ -309,14 +319,20 @@ class MTPProposer(Proposer): self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"]) self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu") self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"]) + self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"]) + self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"]) self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"]) self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"]) self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"]) self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"]) self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"]) + self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone( self.target_model_inputs["decoder_tile_ids_per_batch"] ) + self.model_inputs["target_hidden_states"] = paddle.full( + [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16" + ) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) self.model_inputs["rope_emb"] = get_rope( @@ -457,10 +473,6 @@ class MTPProposer(Proposer): """ Process inputs for prefill tasks and insert it to model_inputs buffer """ - # NOTE: Lazy initialize kv cache - if "caches" not in self.model_inputs: - self.initialize_kv_cache() - # TODO:Init role in initialize process if req_dicts[-1].disaggregate_info is not None: if req_dicts[-1].disaggregate_info["role"] == "prefill": @@ -539,7 +551,7 @@ class MTPProposer(Proposer): request.get("block_tables"), dtype="int32" ) self.model_inputs["not_need_stop"][0] = True - self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] + self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer def _initialize_forward_meta(self): """ @@ -578,6 +590,33 @@ class MTPProposer(Proposer): for attn_backend in self.attn_backends: attn_backend.init_attention_metadata(self.forward_meta) + # Update Batch type for cuda graph + only_decode_batch = True + prefill_exists = None + + # Mix ep in single node + if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed": + only_decode_batch_list = [] + prefill_exists = self.exist_prefill() + paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists) + only_decode_batch = all(only_decode_batch_list) + self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill" + + self.forward_meta.step_use_cudagraph = ( + self.use_cudagraph + and only_decode_batch + and not (prefill_exists if prefill_exists is not None else self.exist_prefill()) + ) + + def exist_prefill(self): + """ + check whether prefill stage exist + """ + if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0: + return 1 + else: + return 0 + def _prepare_inputs(self, full_hidden_states): """ Prepare MTP inputs @@ -621,10 +660,8 @@ class MTPProposer(Proposer): self.target_model_inputs["seq_lens_encoder"], self.num_model_steps, ) - if isinstance(target_hidden_states, list): - target_hidden_states = target_hidden_states[0] - return target_hidden_states + self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def _post_process(self, sampled_token_ids): """ @@ -655,7 +692,7 @@ class MTPProposer(Proposer): self.parallel_config.use_ep, ) - def _propose(self, target_hidden_states): + def _propose(self): """ Main process for MTP inference """ @@ -684,11 +721,17 @@ class MTPProposer(Proposer): self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) - # for speculative decoding - self.model_inputs["output_cum_offsets"] = output_cum_offsets - self.model_inputs["output_padding_offset"] = output_padding_offset + + # For speculative decoding + self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False) + self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False) + + # Initialize forward meta data self._initialize_forward_meta() + # Padding inputs for cuda graph + self.padding_cudagraph_inputs() + # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.model_inputs["temperature"], @@ -709,10 +752,11 @@ class MTPProposer(Proposer): model_output = self.model( ids_remove_padding=self.model_inputs["ids_remove_padding"], - previous_hidden_states=target_hidden_states, + previous_hidden_states=self.model_inputs["target_hidden_states"], forward_meta=self.forward_meta, ) - + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = rebuild_padding( model_output, self.model_inputs["cu_seqlens_q"], @@ -737,9 +781,8 @@ class MTPProposer(Proposer): paddle.distributed.broadcast(sampled_token_ids, 0) self._post_process(sampled_token_ids) - if substep != self.num_model_steps - 1: - target_hidden_states = self._get_self_hidden_states(hidden_states) + self._get_self_hidden_states(hidden_states) def _get_self_hidden_states(self, hidden_states): target_hidden_states = eagle_get_self_hidden_states( @@ -748,10 +791,7 @@ class MTPProposer(Proposer): self.model_inputs["seq_lens_this_time"], self.model_inputs["step_idx"], ) - if isinstance(target_hidden_states, list): - target_hidden_states = target_hidden_states[0] - - return target_hidden_states + self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False) def update_task_chunk_prefill(self, task): """ @@ -836,8 +876,8 @@ class MTPProposer(Proposer): def _run_impl(self, full_hidden_states): """""" - target_hidden_states = self._prepare_inputs(full_hidden_states) - self._propose(target_hidden_states=target_hidden_states) + self._prepare_inputs(full_hidden_states) + self._propose() self._update_status() if self.hybrid_mode: self._extend_draft_token_with_ngram_match() @@ -845,3 +885,16 @@ class MTPProposer(Proposer): def is_chunk_prefill_enabled(self): """""" return True + + def padding_cudagraph_inputs(self) -> None: + """ + Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. + In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. + """ + # In init_attention_metadata, the decode buffer has already been cleared + + # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. + if self.use_cudagraph: + self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] + return diff --git a/fastdeploy/spec_decode/ngram.py b/fastdeploy/spec_decode/ngram.py index 833a45f54..241a61d22 100644 --- a/fastdeploy/spec_decode/ngram.py +++ b/fastdeploy/spec_decode/ngram.py @@ -29,8 +29,8 @@ class NgramProposer(Proposer): Matching corresponding tokens in input and output as draft tokens. """ - def __init__(self, cfg: FDConfig): - super().__init__(cfg) + def __init__(self, fd_config: FDConfig): + super().__init__(fd_config) self.max_ngram_size = self.speculative_config.max_ngram_size self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index f508a9e84..ece6331f1 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -124,6 +124,7 @@ class GPUModelRunner(ModelRunnerBase): "matmul_v2", "fused_gemm_epilogue", ] + # Sampler if not self.speculative_decoding: self.sampler = Sampler(fd_config) @@ -138,8 +139,7 @@ class GPUModelRunner(ModelRunnerBase): # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] - # Cuda Graph - self.graph_opt_level = self.graph_opt_config.graph_opt_level + # CUDA Graph self.use_cudagraph = self.graph_opt_config.use_cudagraph self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes)) self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes @@ -160,7 +160,7 @@ class GPUModelRunner(ModelRunnerBase): # In the future, we will expand it as a list. self.attn_backends: list[AttentionBackend] = [] # self.attn_metadatas: list[AttentionMetadata] = [] - self.initialize_attn_backend() + self._initialize_attn_backend() # Forward meta store the global meta information of the forward self.forward_meta: ForwardMeta = None @@ -1021,7 +1021,6 @@ class GPUModelRunner(ModelRunnerBase): self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False) # NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions self.share_inputs["batch_id_per_token"][:] = -1 - self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False) self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False) self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False) @@ -1035,6 +1034,7 @@ class GPUModelRunner(ModelRunnerBase): # Initialize forward meta data self.initialize_forward_meta() + self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False) # Get sampling metadata self.sampling_metadata = SamplingMetadata( @@ -1152,7 +1152,6 @@ class GPUModelRunner(ModelRunnerBase): # Get kv cache dtype cache_type = self.parallel_config.dtype - kv_cache_quant_type = None if ( self.quant_config @@ -1242,7 +1241,7 @@ class GPUModelRunner(ModelRunnerBase): paddle.device.cuda.empty_cache() - def initialize_attn_backend(self) -> None: + def _initialize_attn_backend(self) -> None: """ Initialize attention backends """ @@ -1312,6 +1311,7 @@ class GPUModelRunner(ModelRunnerBase): expected_decode_len: int = 1, in_capturing: bool = False, capture_prefill: bool = False, + accept_all_drafts: bool = False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. @@ -1320,6 +1320,7 @@ class GPUModelRunner(ModelRunnerBase): expected_decode_len: Expected number of tokens generated in_capturing: Is cuda graph in capturing state capture_prefill: Capture pure prefill for cuda graph + accept_all_drafts: Target model will accept all draft tokens """ input_length_list, max_dec_len_list, block_num = self.get_input_length_list( @@ -1339,8 +1340,8 @@ class GPUModelRunner(ModelRunnerBase): batch_size=batch_size, expected_decode_len=expected_decode_len, ) - while True: + while True: # 1. Initialize forward meta and attention meta data self._prepare_inputs() @@ -1360,6 +1361,8 @@ class GPUModelRunner(ModelRunnerBase): ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = rebuild_padding( model_output, @@ -1404,6 +1407,7 @@ class GPUModelRunner(ModelRunnerBase): self.sampling_metadata, self.parallel_config.max_model_len, self.share_inputs, + accept_all_drafts, ) sampler_output = None if self.parallel_config.tensor_parallel_size > 1: @@ -1470,7 +1474,6 @@ class GPUModelRunner(ModelRunnerBase): skip_save_output=True, zmq_client=self.zmq_client, ) - if self.speculative_decoding: if self.speculative_method == "mtp": self.proposer.run(full_hidden_states=model_output) @@ -1565,7 +1568,6 @@ class GPUModelRunner(ModelRunnerBase): time_before_capture = time.perf_counter() expected_decode_len = 1 capture_sizes = self.cudagraph_capture_sizes.copy() - if self.fd_config.graph_opt_config.cudagraph_only_prefill: for num_tokens in sorted(capture_sizes, reverse=True): self._dummy_run( @@ -1578,6 +1580,46 @@ class GPUModelRunner(ModelRunnerBase): logger.info( f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}" ) + elif self.speculative_decoding and self.speculative_method == "mtp": + # Capture Target Model without bsz 1 + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture target model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=1, + ) + logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}") + # Capture Draft Model without bsz 1 + # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph + for batch_size in sorted(capture_sizes, reverse=True): + if batch_size == 1: + logger.info("Skip token_num = 1, when capture Draft model for mtp") + else: + assert batch_size % 2 == 0 + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=int(batch_size / 2), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=True, + ) + logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}") + # Capture Draft Model with bsz 1 + if 1 in capture_sizes: + self._dummy_run( + num_tokens=self.scheduler_config.max_num_batched_tokens, + batch_size=int(1), + in_capturing=True, + expected_decode_len=3, + accept_all_drafts=False, + ) + logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}") + else: for batch_size in sorted(capture_sizes, reverse=True): self._dummy_run( @@ -1586,9 +1628,7 @@ class GPUModelRunner(ModelRunnerBase): in_capturing=True, expected_decode_len=expected_decode_len, ) - logger.info( - f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}" - ) + logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}") time_after_capture = time.perf_counter() logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds") @@ -1674,6 +1714,8 @@ class GPUModelRunner(ModelRunnerBase): ids_remove_padding=self.share_inputs["ids_remove_padding"], forward_meta=self.forward_meta, ) + if self.use_cudagraph: + model_output = model_output[: self.real_token_num] hidden_states = rebuild_padding( model_output, self.share_inputs["cu_seqlens_q"], @@ -1872,25 +1914,25 @@ class GPUModelRunner(ModelRunnerBase): @profile_run_guard(True) def profile_run(self) -> None: """Execute a forward pass with dummy inputs to profile the memory usage of the model""" - # Initialize kv cache for profile run. After profile run kv cache will be reset. # TODO(gongshaotian): Optimize the management logic of kvcache self.num_gpu_blocks = self.parallel_config.total_block_num self.initialize_kv_cache(profile=True) + if self.speculative_method in ["mtp"]: + self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) # 1. Profile with multimodal encoder & encoder cache # 2. Dummy run self._dummy_run( num_tokens=self.scheduler_config.max_num_batched_tokens, - batch_size=min(self.scheduler_config.max_num_seqs, 3), + batch_size=self.scheduler_config.max_num_seqs, ) # 3. gc self.clear_cache() - if self.speculative_method in ["mtp"]: - self.proposer.clear_dummy_input() + self.proposer.clear_mtp_cache() def update_share_input_block_num(self, num_gpu_blocks: int) -> None: """ @@ -1920,7 +1962,7 @@ class GPUModelRunner(ModelRunnerBase): ) if self.speculative_method in ["mtp"]: - self.proposer.update_block_num(num_gpu_blocks) + self.proposer.update_mtp_block_num(num_gpu_blocks) def cal_theortical_kvcache(self): """ @@ -2017,6 +2059,7 @@ class GPUModelRunner(ModelRunnerBase): # To adapt to CUDA Graph, keep the forward pass at the maximum batch size. if self.use_cudagraph: self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer + self.real_token_num = self.forward_meta.ids_remove_padding.shape[0] return def _init_image_preprocess(self) -> None: diff --git a/fastdeploy/worker/gpu_worker.py b/fastdeploy/worker/gpu_worker.py index 23bd4788e..bf1b88dab 100644 --- a/fastdeploy/worker/gpu_worker.py +++ b/fastdeploy/worker/gpu_worker.py @@ -207,7 +207,7 @@ class GpuWorker(WorkerBase): """ Perform the warm-up and the graph optimization """ - if self.model_runner.graph_opt_level >= 1: + if self.fd_config.graph_opt_config.graph_opt_level >= 1: self.model_runner.sot_warmup() # Trigger cuda graph capture self.model_runner.capture_model() diff --git a/fastdeploy/worker/hpu_model_runner.py b/fastdeploy/worker/hpu_model_runner.py index c21006deb..4323138cf 100644 --- a/fastdeploy/worker/hpu_model_runner.py +++ b/fastdeploy/worker/hpu_model_runner.py @@ -1309,7 +1309,7 @@ class HPUModelRunner(ModelRunnerBase): accept_num=self.share_inputs["accept_num"] if self.speculative_decoding else None, ) - # if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill": + # if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": # skip_save_output = True # else: # skip_save_output = False