"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
from typing import List, Optional

import numpy as np
import paddle
from paddleformers.utils.log import logger

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import Request, RequestType
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.layers.attention import get_attention_backend
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
)
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models import ModelForCasualLM
from fastdeploy.model_executor.ops.gpu import (
    draft_model_postprocess,
    draft_model_preprocess,
    draft_model_update,
    eagle_get_hidden_states,
    eagle_get_self_hidden_states,
    hybrid_mtp_ngram,
    mtp_save_first_token,
    mtp_step_paddle,
    share_external_data,
    speculate_get_logits,
    speculate_save_output_topk,
)
from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding

from .base import Proposer


class MTPProposer(Proposer):
    """
    Proposer for Multi-Token-Prediction (MTP).

    Loads a one-layer MTP draft model next to the main (target) model, keeps its own
    ``model_inputs`` buffers (mostly cloned from, or aliased to, the target model's
    shared inputs) and runs ``num_model_steps`` draft sub-steps per call to ``_run_impl``.
    Draft tokens are written into ``target_model_inputs["draft_tokens"]`` through the
    ``base_model_draft_tokens`` alias.
    """

    def __init__(
        self,
        fd_config: FDConfig,
        main_model: ModelForCasualLM,
        local_rank: int,
        device_id: int,  # physical device id
        target_model_inputs,  # main model share inputs
    ):
        super().__init__(fd_config)
        # Captured before _update_mtp_config overwrites num_hidden_layers with 1.
        self.num_main_model_layers = self.model_config.num_hidden_layers
        self.local_rank = local_rank
        self.device_id = device_id
        self._update_mtp_config(main_model)
        self._load_model()
        self.target_model_inputs = target_model_inputs
        self.mtp_strategy = self.speculative_config.mtp_strategy
        # Hybrid mode: MTP proposes num_model_steps tokens, n-gram matching fills the rest.
        self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps
        self.enable_logprob = self.model_config.enable_logprob

        # [mixed, prefill, decoder]
        self.role = "mixed"

        self.sampler = MTPSampler(fd_config)
        self._init_model_inputs()

        # CUDA Graph
        self.use_cudagraph = False  # TODO(gongshaotian): Use Target Model flag
        self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

        self.attn_backends: list[AttentionBackend] = []
        self._initialize_attn_backend()

    def _update_mtp_config(self, main_model):
        """
        Update config for MTP from global config.

        Rewrites the shared model config in place so that the loader builds the
        single-layer MTP architecture whose layer index continues after the main model.
        """
        self.forward_meta: ForwardMeta = None
        self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
        self.speculative_config.sharing_model = main_model
        self.model_config.num_hidden_layers = 1
        self.model_config.model = self.speculative_config.model
        self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
        if self.speculative_config.quantization != "":
            self.model_config.quantization = self.speculative_config.quantization
        self.model_config.start_layer_index = self.num_main_model_layers
        self.speculative_config.model_type = "mtp"

    def _load_model(self):
        """
        Load MTP Layer
        """
        model_loader = get_model_loader(load_config=self.fd_config.load_config)
        self.model = model_loader.load_model(fd_config=self.fd_config)

    def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
        """Set dummy prefill inputs to model_inputs (used for warm-up / profiling runs).

        Args:
            num_tokens: total token budget split evenly across ``batch_size`` requests.
            batch_size: number of dummy requests to fill; must be > 0.
            expected_decode_len: decode length to reserve; ``max_dec_len`` is this plus one.
        """
        max_dec_len = expected_decode_len + 1
        input_length = min(
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )

        # Expert-parallel dummy runs are capped at 32 tokens per request.
        if self.fd_config.parallel_config.enable_expert_parallel:
            input_length = min(input_length, 32)

        # ceil(input_length / block_size) plus the configured encoder/decoder reserve.
        block_num = (
            input_length + self.cache_config.block_size - 1
        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num

        # Loop-invariant: the same EOS id is written for every dummy request.
        self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)

        for idx in range(batch_size):
            self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            self.model_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
            self.model_inputs["stop_flags"][idx : idx + 1] = False

            self.model_inputs["encoder_block_lens"][idx : idx + 1] = block_num
            # Each dummy request gets its own contiguous, non-overlapping range of block ids.
            self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                idx * block_num, (idx + 1) * block_num, 1
            )
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

    def initialize_kv_cache(self, main_model_num_blocks: Optional[int] = None, profile: bool = False):
        """
        Initialize the KV cache of the MTP layer(s).

        Args:
            main_model_num_blocks: number of KV-cache blocks of the main model. The MTP cache
                gets ``int(main_model_num_blocks * num_gpu_block_expand_ratio)`` blocks. When
                None, the value from the previous call (or ``update_mtp_block_num``) is reused.
            profile: when True, always allocate fresh zero-filled tensors instead of attaching
                to externally shared cache tensors.

        Raises:
            ValueError: ``main_model_num_blocks`` is None and no previous value is known.
        """
        if main_model_num_blocks is None:
            main_model_num_blocks = getattr(self, "main_model_num_gpu_blocks", None)
            if main_model_num_blocks is None:
                raise ValueError(
                    "main_model_num_blocks is unknown; pass it explicitly or call update_mtp_block_num first."
                )
        # Remember the block count so a later re-initialization (e.g. from insert_tasks_v1
        # after clear_mtp_cache) can rebuild the cache without the caller supplying it.
        self.main_model_num_gpu_blocks = main_model_num_blocks
        self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
        self.cache_kvs = {}

        # Get kv cache dtype
        cache_type = self.parallel_config.dtype

        kv_cache_quant_type = None
        if (
            self.quant_config
            and hasattr(self.quant_config, "kv_cache_quant_type")
            and self.quant_config.kv_cache_quant_type is not None
        ):
            cache_type = "uint8"
            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

        # Get kv cache shape
        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
        )
        if kv_cache_quant_type == "block_wise_fp8":
            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
            # Attach to cache tensors published under well-known names; they are presumably
            # allocated by a separate cache manager process (TODO confirm). Layer indices continue
            # after the main model's layers.
            cache_kvs_list = []
            for i in range(
                self.num_main_model_layers,
                self.num_main_model_layers + self.model_config.num_hidden_layers,
            ):
                key_cache = paddle.empty(shape=[], dtype=cache_type)
                key_cache_name = f"key_caches_{i}_rank{local_rank}.device{self.device_id}"
                val_cache_name = f"value_caches_{i}_rank{local_rank}.device{self.device_id}"
                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                cache_kvs_list.append(key_cache)
                value_cache = paddle.empty(shape=[], dtype=cache_type)
                value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
                cache_kvs_list.append(value_cache)

            self.model_inputs["caches"] = cache_kvs_list
        else:
            for i in range(self.model_config.num_hidden_layers):
                self.cache_kvs[f"key_caches_{i}"] = paddle.full(
                    shape=kv_cache_shape,
                    fill_value=0,
                    dtype=cache_type,
                )
                self.cache_kvs[f"value_caches_{i}"] = paddle.full(
                    shape=kv_cache_shape,
                    fill_value=0,
                    dtype=cache_type,
                )
                if kv_cache_quant_type == "block_wise_fp8":
                    self.cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
                        shape=kv_cache_scale_shape,
                        fill_value=0,
                        dtype=paddle.get_default_dtype(),
                    )
                    self.cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
                        shape=kv_cache_scale_shape,
                        fill_value=0,
                        dtype=paddle.get_default_dtype(),
                    )
            # Insertion order of cache_kvs defines the order of the caches list.
            self.model_inputs["caches"] = list(self.cache_kvs.values())
        paddle.device.cuda.empty_cache()

    def _initialize_attn_backend(
        self,
    ) -> None:
        """
        Initialize attention backends and forward metadata
        """
        assert len(self.attn_backends) == 0

        num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
        # At least one KV head per rank, even when TP size exceeds the KV head count.
        self.model_config.kv_num_heads = max(
            1,
            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
        )
        head_dim = self.model_config.head_dim

        # Initialize AttentionBackend buffers
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16

        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["decoder_tile_ids_per_batch"]
        )
        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_cpu"]
        ).pin_memory()
        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
            self.target_model_inputs["max_len_tensor_cpu"]
        ).cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(
            self.fd_config,
            kv_num_heads=self.model_config.kv_num_heads,
            num_heads=num_heads,
            head_dim=head_dim,
            encoder_block_shape_q=encoder_block_shape_q,
            decoder_block_shape_q=decoder_block_shape_q,
        )
        if attn_backend is None:
            raise NotImplementedError(
                "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
            )
        self.attn_backends.append(attn_backend)

    def clear_mtp_cache(self):
        """
        Clear allocated cacheKV
        """
        del self.model_inputs["caches"]
        if self.forward_meta is not None:
            del self.forward_meta.caches

    def update_mtp_block_num(self, num_gpu_blocks) -> None:
        """
        Update MTP block num by theoretical calculation.

        Re-creates the KV cache for ``num_gpu_blocks`` main-model blocks and rebuilds the
        free list from the block ids above ``main_model_num_gpu_blocks * kv_cache_ratio``.
        """
        # Reset block table and kv cache with global block num
        self.main_model_num_gpu_blocks = num_gpu_blocks
        self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)

        # Reset free list
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
                int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(free_list)
        self.model_inputs.update(
            {
                "free_list": paddle.to_tensor(free_list, dtype="int32"),
                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
            }
        )

    def _init_model_inputs(self):
        """
        Init model inputs.

        Buffers the draft model mutates independently are cloned from the target model's
        inputs; sampling hyperparameters and a few output buffers are aliased (shared).
        """
        self.model_inputs = {}
        # Same shape/dtype with base model
        self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
        self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
        self.model_inputs["input_ids_cpu"] = paddle.full(
            shape=[self.max_num_seqs, self.parallel_config.max_model_len],
            fill_value=-1,
            dtype="int64",
        ).cpu()
        self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])

        self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
        self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
        self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
        self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
        self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
        self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
        self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
        self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"])
        self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"])
        self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
        self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
        self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
        self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
        # NOTE(review): these two clones are reset to None below and re-created in
        # _initialize_attn_backend; kept to preserve model_inputs key order.
        self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
            self.target_model_inputs["decoder_tile_ids_per_batch"]
        )

        self.model_inputs["target_hidden_states"] = paddle.full(
            [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
        )

        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
        self.model_inputs["rope_emb"] = get_rope(
            rotary_dim=self.model_config.head_dim,
            position_ids=tmp_position_ids,
            base=self.model_config.rope_theta,
            model_config=self.model_config,
        )
        # self.model_inputs["caches"] = self.cache_kvs
        # Inherit generation hyperparameters from the main model for consistency
        self.model_inputs["top_p"] = (
            self.target_model_inputs["top_p"]
            if envs.FD_SPECULATE_SAMPLING_TOP_P is None
            else paddle.full_like(self.target_model_inputs["top_p"], envs.FD_SPECULATE_SAMPLING_TOP_P)
        )
        self.model_inputs["top_k"] = (
            self.target_model_inputs["top_k"]
            if envs.FD_SPECULATE_SAMPLING_TOP_K is None
            else paddle.full_like(self.target_model_inputs["top_k"], envs.FD_SPECULATE_SAMPLING_TOP_K)
        )
        self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
        self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
        self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
        self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"]
        self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"]
        self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"]

        self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"]
        self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]

        self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]

        # Integrate the updated results in model forward
        self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
        self.model_inputs["substep"] = 0

        # Declare AttentionBackend buffers
        self.model_inputs["decoder_batch_ids"] = None
        self.model_inputs["decoder_tile_ids_per_batch"] = None
        self.model_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
        self.model_inputs["max_len_tensor_cpu"] = None  # CPU

        # Input tokens
        self.model_inputs["draft_tokens"] = paddle.full(
            shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
        )

        self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"])

        # Free block ids, descending from total_block_num - 1 down to
        # int(total_block_num * kv_cache_ratio) (inclusive).
        self.free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(self.free_list)

        self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
        self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")

        self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
        self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
        self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
        if self.num_model_steps > 1:
            self.last_seq_lens_this_time = paddle.full_like(
                self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
            )
        self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()

        self.model_inputs["temp_scaled_logprobs"] = self.target_model_inputs["temp_scaled_logprobs"]
        self.model_inputs["top_p_normalized_logprobs"] = self.target_model_inputs["top_p_normalized_logprobs"]
        self.model_inputs["accept_num"] = self.target_model_inputs["accept_num"]
        self.model_inputs["accept_tokens"] = self.target_model_inputs["accept_tokens"]
        self.model_inputs["draft_logits"] = self.target_model_inputs["draft_logits"]
        self.model_inputs["first_token_hidden_states"] = paddle.full(
            [self.max_num_seqs, self.model_config.hidden_size], -1
        )
        self.model_inputs["batch_token_num"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
        self.model_inputs["next_token_num"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
        self.model_inputs["cu_batch_token_offset"] = paddle.full_like(
            self.target_model_inputs["cu_batch_token_offset"], fill_value=0, dtype="int32"
        )
        self.model_inputs["cu_next_token_offset"] = paddle.full(
            shape=[self.max_num_seqs + 1], fill_value=0, dtype="int32"
        )

    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
        """
        Insert scheduled requests (V1 KV-cache scheduler) into the draft model buffers.

        PREFILL requests get their (shifted-by-one) prompt slice, block table and
        sequence lengths; DECODE requests only refresh their block table; any other
        request type clears its slot. ``seq_lens_this_time`` is then narrowed to the
        first ``num_running_requests`` entries.
        """
        if "caches" not in self.model_inputs:
            # Re-create the cache (e.g. after clear_mtp_cache) using the last known
            # main-model block count; previously this call raised TypeError because the
            # required block-count argument was missing.
            self.initialize_kv_cache()
        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            logger.debug(f"{i}th request-{request.request_id}: {request}")
            idx = request.idx
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                prefill_start_index = request.prefill_start_index
                prefill_end_index = request.prefill_end_index
                length = prefill_end_index - prefill_start_index
                input_ids = request.prompt_token_ids + request.output_token_ids

                self.input_ids_len[idx] = length
                self.model_inputs["pre_ids"][idx : idx + 1] = -1
                # The draft model consumes the target model's input shifted left by one token.
                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
                    idx : idx + 1, 1:length
                ]
                encoder_block_num = len(request.block_tables)
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False

                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
                self.seq_lens_this_time_buffer[idx : idx + 1] = length
                # step_idx only advances once the final prefill chunk has been scheduled.
                self.model_inputs["step_idx"][idx : idx + 1] = (
                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
                )

                # has_prefill_task = True
            elif request.task_type.value == RequestType.DECODE.value:  # decode task
                encoder_block_num = len(request.block_tables)
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
                # if self.model_inputs["is_block_step"][idx]:  # has tasks to continue to decode
                #     has_decode_task = True
                #     continue
            else:
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["stop_flags"][idx : idx + 1] = True
                self.seq_lens_this_time_buffer[idx : idx + 1] = 0
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.model_inputs["is_block_step"][idx : idx + 1] = False
                continue

        # if has_prefill_task or has_decode_task:
        #     self.model_inputs["not_need_stop"][0] = True
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
        """
        Process inputs for prefill tasks and insert it to model_inputs buffer.

        The role (prefill / decode / mixed) is taken from the last request's
        ``disaggregate_info``. Requests whose role is "decode" are set up to start
        drafting immediately (``seq_lens_encoder`` = 0); all others are set up as prefill.
        """
        # TODO:Init role in initialize process
        if req_dicts[-1].disaggregate_info is not None:
            if req_dicts[-1].disaggregate_info["role"] == "prefill":
                self.role = "prefill"
                os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
            elif req_dicts[-1].disaggregate_info["role"] == "decode":
                self.role = "decode"
            else:
                self.role = "mixed"

        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = len(request.prompt_token_ids)
            self.input_ids_len[idx] = length - 1

            if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                if length > 1:
                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length]
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                        request.prompt_token_ids
                    )[1:]
                self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                prefill_token_num = self.max_draft_token_num + 1
                self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
                    request.draft_token_ids[1:2], dtype="int64"
                )

                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
                self.seq_lens_this_time_buffer[idx : idx + 1] = prefill_token_num

                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False
                self.model_inputs["step_idx"][idx : idx + 1] = 1
                encoder_block_num = len(request.block_tables)

                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )

            else:
                if length > 1:
                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length]
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                        request.prompt_token_ids
                    )[1:]
                self.model_inputs["pre_ids"][idx : idx + 1] = -1
                self.model_inputs["step_idx"][idx : idx + 1] = 0
                if self.cache_config.enable_chunked_prefill:
                    token_chunk_size = request.prefill_chunk_info[0]
                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                    self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                else:
                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
                    self.seq_lens_this_time_buffer[idx : idx + 1] = length

                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False

                encoder_block_num = len(request.get("block_tables"))
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.get("block_tables"), dtype="int32"
                )
        self.model_inputs["not_need_stop"][0] = True
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

    def _initialize_forward_meta(self, step_use_cudagraph: bool = False):
        """
        Initialize forward meta and attention meta data
        """
        # Initialize forward meta
        self.forward_meta = ForwardMeta(
            input_ids=self.model_inputs["input_ids"],
            ids_remove_padding=self.model_inputs["ids_remove_padding"],
            rotary_embs=self.model_inputs["rope_emb"],
            attn_backend=self.attn_backends[0],
            decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
            max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
            batch_id_per_token=self.model_inputs["batch_id_per_token"],
            cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
            cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
            block_tables=self.model_inputs["block_tables"],
            caches=self.model_inputs["caches"],
        )

        # Initialize attention meta data
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)

        # TODO(gongshaotian): Use CUDAGraph with Draft Model
        self.forward_meta.step_use_cudagraph = step_use_cudagraph and self.use_cudagraph

    def exist_prefill(self):
        """
        Check whether a prefill stage exists.

        Returns:
            1 if any entry of ``seq_lens_encoder`` is non-zero, else 0.
        """
        if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0:
            return 1
        else:
            return 0

    def _prepare_inputs(self, full_hidden_states):
        """
        Prepare MTP inputs.

        Runs the ``draft_model_preprocess`` custom op over the draft and target buffers,
        then gathers the target model's hidden states into ``target_hidden_states``.
        """
        use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
        draft_model_preprocess(
            self.model_inputs["draft_tokens"],
            self.model_inputs["input_ids"],
            self.model_inputs["stop_flags"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["not_need_stop"],
            self.model_inputs["batch_drop"],
            self.model_inputs["is_block_step"],
            self.model_inputs["pre_ids"],
            self.target_model_inputs["accept_tokens"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["seq_lens_decoder"],
            self.target_model_inputs["step_idx"],
            self.target_model_inputs["stop_flags"],
            self.target_model_inputs["is_block_step"],
            self.target_model_inputs["draft_tokens"],
            self.num_model_steps,
            self.speculative_method in ["eagle", "mtp"],
            self.role == "prefill",
            use_v1_cache_scheduler,
        )

        target_hidden_states = eagle_get_hidden_states(
            full_hidden_states,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["stop_flags"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.num_model_steps,
        )

        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

    def _post_process(self, sampled_token_ids):
        """
        PostProcess for generation.

        Applies ``draft_model_update`` with the sampled tokens. On a prefill-role node,
        TP rank 0 additionally calls ``mtp_save_first_token``.
        """
        draft_model_update(
            sampled_token_ids,
            self.model_inputs["draft_tokens"],
            self.model_inputs["pre_ids"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["output_cum_offsets"],
            self.model_inputs["stop_flags"],
            self.model_inputs["not_need_stop"],
            self.model_inputs["max_dec_len"],
            self.model_inputs["eos_token_id"],
            self.model_inputs["base_model_draft_tokens"],
            self.max_model_len,
            self.model_inputs["substep"],
        )
        if self.role == "prefill" and self.parallel_config.tensor_parallel_rank == 0:
            mtp_save_first_token(
                self.model_inputs["base_model_draft_tokens"],
                self.model_inputs["not_need_stop"],
                self.local_rank,
                self.parallel_config.use_ep,
            )

    def _propose(self, step_use_cudagraph: bool = False):
        """
        Main process for MTP inference.

        Args:
            step_use_cudagraph: Whether to use cuda graph. Use the target model flag to
                avoid hanging problems with EP.
        """
        for substep in range(self.num_model_steps):
            if self.model_inputs["not_need_stop"]:
                self.model_inputs["substep"] = substep
                # Remove padding
                (
                    ids_remove_padding,
                    batch_id_per_token,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    output_cum_offsets,
                    output_padding_offset,
                ) = pre_process(
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    True,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                )

                # Initialize forward meta data
                self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
                self.model_inputs["batch_id_per_token"][:] = -1
                self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
                # for speculative decoding
                self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
                self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

                # Initialize forward meta data
                self._initialize_forward_meta(step_use_cudagraph=step_use_cudagraph)
                self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

                # Padding inputs for cuda graph
                self.padding_cudagraph_inputs()

                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    step_idx=self.model_inputs["step_idx"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                    max_num_logprobs=20 if self.enable_logprob else None,
                    temp_scaled_logprobs=self.model_inputs["temp_scaled_logprobs"],
                    top_p_normalized_logprobs=self.model_inputs["top_p_normalized_logprobs"],
                    share_inputs=self.model_inputs,
                )

                # Snapshot lengths before _post_process mutates them; _get_self_hidden_states needs both.
                if self.num_model_steps > 1:
                    self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=self.model_inputs["target_hidden_states"],
                    forward_meta=self.forward_meta,
                )
                if self.forward_meta.step_use_cudagraph:
                    model_output = model_output[: self.real_token_num]

                hidden_states = rebuild_padding(
                    model_output,
                    self.model_inputs["cu_seqlens_q"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs["seq_lens_decoder"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["output_padding_offset"],
                    self.parallel_config.max_model_len,
                    self.model_inputs["first_token_hidden_states"],
                    self.enable_logprob if substep == 0 else False,
                )

                # 4. Compute logits, Sample
                logits = self.model.compute_logits(hidden_states)
                if self.enable_logprob and substep == 0:
                    first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])

                    speculate_get_logits(
                        self.model_inputs["draft_logits"],
                        self.model_inputs["next_token_num"],
                        self.model_inputs["batch_token_num"],
                        self.model_inputs["cu_next_token_offset"],
                        self.model_inputs["cu_batch_token_offset"],
                        logits,
                        first_token_logits,
                        self.model_inputs["seq_lens_this_time"],
                        self.model_inputs["seq_lens_encoder"],
                    )

                sampled_token_ids, sampler_output = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )

                if substep == 0 and sampler_output.logprobs_tensors is not None:
                    real_bsz = self.model_inputs["seq_lens_this_time"].shape[0]
                    speculate_save_output_topk(
                        sampler_output.sampled_token_ids,
                        sampler_output.logprobs_tensors.logprob_token_ids,
                        sampler_output.logprobs_tensors.logprobs,
                        sampler_output.logprobs_tensors.selected_token_ranks,
                        self.model_inputs["batch_token_num"][:real_bsz],
                        self.model_inputs["cu_batch_token_offset"][:real_bsz],
                        self.model_inputs["not_need_stop"],
                        4,  # mtype
                        self.local_rank,
                    )

                # Keep draft tokens identical across TP ranks: broadcast from the first rank of this DP group.
                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(
                        sampled_token_ids,
                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
                        group=self.parallel_config.tp_group,
                    )

                self._post_process(sampled_token_ids)

                if substep != self.num_model_steps - 1:
                    self._get_self_hidden_states(hidden_states)
            else:
                if hasattr(self.model, "empty_input_forward"):
                    self.model.empty_input_forward()

    def _get_self_hidden_states(self, hidden_states):
        """Feed this sub-step's hidden states back as ``target_hidden_states`` for the next sub-step."""
        target_hidden_states = eagle_get_self_hidden_states(
            hidden_states,
            self.last_seq_lens_this_time,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["step_idx"],
        )
        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

    def update_task_chunk_prefill(self, task):
        """
        Update single task's chunk_prefill info.

        When every chunk has been consumed (``chunk_idx == len(prefill_chunk_info)``), the
        slot is switched to decoding. Otherwise the next chunk's tokens, shifted by one,
        are loaded into ``input_ids``.
        """
        idx = task.idx
        start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])

        if task.chunk_idx == len(task.prefill_chunk_info):
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
            self.model_inputs["step_idx"][idx : idx + 1] = 1
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
        else:
            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]

            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
                self.model_inputs["input_ids"][idx, :token_chunk_size] = np.array(
                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1]
                )
            # Last prefill: the shifted slice is one token shorter because the prompt ends here.
            else:
                self.model_inputs["input_ids"][idx, : token_chunk_size - 1] = np.array(
                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size]
                )

            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)

    def _update_status(self):
        """
        Update main-model's forward info in next step.
        Allocate/Free block of MTP.
        """
        draft_model_postprocess(
            self.target_model_inputs["draft_tokens"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["stop_flags"],
        )
        # Block allocation is done here only for the legacy scheduler.
        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            mtp_step_paddle(
                self.target_model_inputs["stop_flags"],
                self.model_inputs["stop_flags"],
                self.model_inputs["batch_drop"],
                self.model_inputs["seq_lens_this_time"],
                self.model_inputs["seq_lens_encoder"],
                self.model_inputs["seq_lens_decoder"],
                self.model_inputs["block_tables"],
                self.model_inputs["encoder_block_lens"],
                self.model_inputs["used_list_len"],
                self.model_inputs["free_list"],
                self.model_inputs["free_list_len"],
                self.cache_config.block_size,
                self.max_draft_token_num,
            )

    def _extend_draft_token_with_ngram_match(self):
        """Append n-gram-matched draft tokens after the MTP drafts (hybrid mode), on CPU."""
        # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce latency
        device = paddle.CUDAPinnedPlace()

        draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
        seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
        seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
        hybrid_mtp_ngram(
            self.model_inputs["input_ids_cpu"],
            self.input_ids_len,
            self.model_inputs["pre_ids"]._copy_to(device, True),
            self.model_inputs["step_idx"].cpu(),
            self.target_model_inputs["actual_draft_token_num"].cpu(),
            draft_tokens,
            seq_lens_this_time,
            seq_lens_decoder,
            self.model_inputs["max_dec_len"].cpu(),
            self.max_ngram_size,
            self.min_ngram_size,
            self.max_draft_token_num,
        )
        # The op updates the CPU copies in place; write them back to the device tensors.
        self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
        self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

    def _run_impl(self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False):
        """Execute Draft Model"""
        self._prepare_inputs(full_hidden_states)
        self._propose(step_use_cudagraph=step_use_cudagraph)
        self._update_status()
        if self.hybrid_mode:
            self._extend_draft_token_with_ngram_match()

    def is_chunk_prefill_enabled(self):
        """Return True: this proposer handles chunked prefill (see update_task_chunk_prefill)."""
        return True

    def padding_cudagraph_inputs(self) -> None:
        """
        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean
        when replaying the CUDA graph with the padded batch.
        """
        # In init_attention_metadata, the decode buffer has already been cleared

        # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
        if self.forward_meta.step_use_cudagraph:
            self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
        return