""" # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ import os from typing import List import numpy as np import paddle from paddle import nn from paddleformers.utils.log import logger from fastdeploy import envs from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request, RequestType from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.sampler import MTPSampler from fastdeploy.model_executor.ops.gpu import ( draft_model_postprocess, draft_model_preprocess, draft_model_update, eagle_get_hidden_states, eagle_get_self_hidden_states, hybrid_mtp_ngram, mtp_save_first_token, mtp_step_paddle, share_external_data, ) from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding from .base import Proposer class MTPProposer(Proposer): """ Proposer for Multi-Token-Prediction(MTP) """ def __init__(self, cfg: FDConfig, main_model: nn.Layer, local_rank: int, device_id: int, target_model_inputs): super().__init__(cfg) self.num_main_model_layers = self.model_config.num_hidden_layers self.local_rank = local_rank self.device_id = device_id self._update_cfg(main_model) self._load_model() self.target_model_inputs = target_model_inputs self.mtp_strategy = self.speculative_config.mtp_strategy self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps # [mixed, prefill, decoder] self.role = "mixed" self.sampler = MTPSampler(cfg) self._init_model_inputs() self.attn_backends: list[AttentionBackend] = [] self._initialize_attn_backend() def _update_cfg(self, main_model): """ Update config for MTP from global config """ self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM" self.speculative_config.sharing_model = main_model self.model_config.num_hidden_layers = 1 self.model_config.model = self.speculative_config.model self.model_config.pretrained_config.prefix_name = "ernie.mtp_block" if self.speculative_config.quantization != "": self.model_config.quantization = self.speculative_config.quantization self.model_config.start_layer_index = self.num_main_model_layers self.speculative_config.model_type = "mtp" def _load_model(self): """ Load MTP Layer """ from fastdeploy.model_executor.model_loader import get_model_loader model_loader = get_model_loader(load_config=self.cfg.load_config) self.model = model_loader.load_model(fd_config=self.cfg) def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int): """Set dummy prefill inputs to model_inputs""" max_dec_len = expected_decode_len + 1 self.num_gpu_blocks = self.parallel_config.total_block_num 
        self.initialize_kv_cache()
        full_length = min(
            num_tokens // batch_size,
            self.parallel_config.max_model_len - max_dec_len,
        )
        input_length = int(full_length * self.cache_config.kv_cache_ratio)
        block_num = (
            input_length + self.cache_config.block_size - 1
        ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
        for i in range(batch_size):
            idx = i
            self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
            self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
            self.seq_lens_this_time_buffer[idx : idx + 1] = input_length
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            self.model_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
            self.model_inputs["stop_flags"][idx : idx + 1] = False
            self.model_inputs["encoder_block_lens"][idx : idx + 1] = block_num
            self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
                idx * block_num, (idx + 1) * block_num, 1
            )
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

    def initialize_kv_cache(self):
        """
        Initialize kv cache
        """
        # prompt cache
        self.cache_kvs = {}

        cache_type = self.parallel_config.dtype

        kv_cache_quant_type = None
        if (
            self.quant_config
            and hasattr(self.quant_config, "kv_cache_quant_type")
            and self.quant_config.kv_cache_quant_type is not None
        ):
            cache_type = "uint8"
            kv_cache_quant_type = self.quant_config.kv_cache_quant_type

        # Get kv cache shape
        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
        )

        if not self.parallel_config.do_profile and (
            self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
        ):
            cache_kvs_list = []
            for i in range(
                self.num_main_model_layers,
                self.num_main_model_layers + self.model_config.num_hidden_layers,
            ):
                key_cache = paddle.empty(shape=[], dtype=cache_type)
                key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
                val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                cache_kvs_list.append(key_cache)
                value_cache = paddle.empty(shape=[], dtype=cache_type)
                value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
                cache_kvs_list.append(value_cache)

            self.model_inputs["caches"] = cache_kvs_list
        else:
            for i in range(self.model_config.num_hidden_layers):
                self.cache_kvs[f"key_caches_{i}"] = paddle.full(
                    shape=kv_cache_shape,
                    fill_value=0,
                    dtype=cache_type,
                )
                self.cache_kvs[f"value_caches_{i}"] = paddle.full(
                    shape=kv_cache_shape,
                    fill_value=0,
                    dtype=cache_type,
                )
            self.model_inputs["caches"] = list(self.cache_kvs.values())
        for value in self.cache_kvs.values():
            del value
        paddle.device.cuda.empty_cache()

    def _initialize_attn_backend(
        self,
    ) -> None:
        """
        Initialize attention backends and forward metadata
        """
        assert len(self.attn_backends) == 0

        num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
        self.model_config.kv_num_heads = max(
            1,
            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
        )
        head_dim = self.model_config.head_dim

        # Initialize AttentionBackend buffers
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16

        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["decoder_tile_ids_per_batch"]
        )
        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_cpu"]
        ).pin_memory()
        self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like(
            self.target_model_inputs["decoder_num_blocks_device"]
        )
        self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like(
            self.target_model_inputs["decoder_chunk_size_device"]
        )
        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
            self.target_model_inputs["max_len_tensor_cpu"]
        ).cpu()
        self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"])
        self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["encoder_tile_ids_per_batch"]
        )
        self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like(
            self.target_model_inputs["encoder_num_blocks_x_cpu"]
        ).cpu()
        self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"])
        self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like(
            self.target_model_inputs["kv_tile_ids_per_batch"]
        )
        self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like(
            self.target_model_inputs["kv_num_blocks_x_cpu"]
        ).cpu()
        self.model_inputs["max_len_kv_cpu"] = paddle.zeros_like(self.target_model_inputs["max_len_kv_cpu"]).cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(
            self.cfg,
            kv_num_heads=self.model_config.kv_num_heads,
            num_heads=num_heads,
            head_dim=head_dim,
            encoder_block_shape_q=encoder_block_shape_q,
            decoder_block_shape_q=decoder_block_shape_q,
        )
        if attn_backend is None:
            raise NotImplementedError(
                "Attention backend which you specified is not supported, please set FD_ATTENTION_BACKEND correctly."
            )
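        # Keep the backend around; it is reused for the KV cache shape and per-step forward metadata.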
        self.attn_backends.append(attn_backend)

    def clear_dummy_input(self):
        """
        Clear allocated cacheKV
        """
        del self.model_inputs["caches"]
        if self.forward_meta is not None:
            del self.forward_meta.caches

    def update_block_num(self, num_gpu_blocks) -> None:
        """
        Update block num by theoretical calculation
        """
        self.main_model_num_gpu_blocks = num_gpu_blocks
        self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
        if not (self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"):
            self.initialize_kv_cache()

        # Reset free list
        free_list = list(
            range(
                self.num_gpu_blocks - 1,
                int(self.main_model_num_gpu_blocks * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(free_list)
        self.model_inputs.update(
            {
                "free_list": paddle.to_tensor(free_list, dtype="int32"),
                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
            }
        )
        self.parallel_config.do_profile = False

    def _init_model_inputs(self):
        """
        Init model inputs
        """
        self.model_inputs = {}
        # Same shape/dtype as the base model
        self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
        self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
        self.model_inputs["input_ids_cpu"] = paddle.full(
            shape=[self.max_num_seqs, self.parallel_config.max_model_len],
            fill_value=-1,
            dtype="int64",
        ).cpu()
        self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])
        self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
        self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
        self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
        self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
        self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
        self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
        self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
        self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
        self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
        self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
        self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
        self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
            self.target_model_inputs["decoder_tile_ids_per_batch"]
        )

        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
        self.model_inputs["rope_emb"] = get_rope(
            rotary_dim=self.model_config.head_dim,
            position_ids=tmp_position_ids,
            base=self.model_config.rope_theta,
            model_config=self.model_config,
        )
        # self.model_inputs["caches"] = self.cache_kvs
        # Inherit generation hyperparameters from the main model for consistency
        self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
        self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
        self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
        self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
        self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
        self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"]
        self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"]
        self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"]
        self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"]
        self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]
        self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]

        # Integrate the updated results in model forward
        self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
        self.model_inputs["substep"] = 0

        # Declare AttentionBackend buffers
        self.model_inputs["decoder_batch_ids"] = None
        self.model_inputs["decoder_tile_ids_per_batch"] = None
        self.model_inputs["decoder_num_blocks_cpu"] = None  # Pinned memory
        self.model_inputs["decoder_num_blocks_device"] = None
        self.model_inputs["decoder_chunk_size_device"] = None
        self.model_inputs["max_len_tensor_cpu"] = None  # CPU
        self.model_inputs["encoder_batch_ids"] = None
        self.model_inputs["encoder_tile_ids_per_batch"] = None
        self.model_inputs["encoder_num_blocks_x_cpu"] = None  # CPU
        self.model_inputs["kv_batch_ids"] = None
        self.model_inputs["kv_tile_ids_per_batch"] = None
        self.model_inputs["kv_num_blocks_x_cpu"] = None  # CPU
        self.model_inputs["max_len_kv_cpu"] = None  # CPU

        # Input tokens
        self.model_inputs["draft_tokens"] = paddle.full(
            shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
        )

        self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"])

        self.free_list = list(
            range(
                self.parallel_config.total_block_num - 1,
                int(self.parallel_config.total_block_num * self.cache_config.kv_cache_ratio) - 1,
                -1,
            )
        )
        self.free_list_len = len(self.free_list)
        self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
        self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")

        self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
        self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
        self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")

        if self.num_model_steps > 1:
            self.last_seq_lens_this_time = paddle.full_like(
                self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
            )
        self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()

    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
        """
        Process prefill/decode tasks and insert them into the model_inputs buffer (v1 KV cache scheduler path)
        """
        if "caches" not in self.model_inputs:
            self.initialize_kv_cache()

        req_len = len(req_dicts)
        # has_prefill_task = False
        # has_decode_task = False
        for i in range(req_len):
            request = req_dicts[i]
            logger.info(f"{i}th request-{request.request_id}: {request}")
            idx = request.idx
            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
                prefill_start_index = request.prefill_start_index
                prefill_end_index = request.prefill_end_index
                length = prefill_end_index - prefill_start_index
                input_ids = request.prompt_token_ids + request.output_token_ids
                self.input_ids_len[idx] = length - 1
                self.model_inputs["pre_ids"][idx : idx + 1] = -1
                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
                    idx : idx + 1, 1:length
                ]
                self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                    "input_ids"
                ][idx : idx + 1, 1:length].cpu()
                encoder_block_num = len(request.block_tables)
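                # Copy this request's block table so the draft layer addresses the same block ids as the main model.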
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
                self.seq_lens_this_time_buffer[idx : idx + 1] = length
                self.model_inputs["step_idx"][idx : idx + 1] = (
                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
                )
                # has_prefill_task = True
            elif request.task_type.value == RequestType.DECODE.value:  # decode task
                encoder_block_num = len(request.block_tables)
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
                # if self.model_inputs["is_block_step"][idx]:  # has tasks to continue to decode
                #     has_decode_task = True
                #     continue
            else:
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["stop_flags"][idx : idx + 1] = True
                self.seq_lens_this_time_buffer[idx : idx + 1] = 0
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.model_inputs["is_block_step"][idx : idx + 1] = False
                continue
        # if has_prefill_task or has_decode_task:
        #     self.model_inputs["not_need_stop"][0] = True
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

    def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
        """
        Process inputs for prefill tasks and insert them into the model_inputs buffer
        """
        # NOTE: Lazy initialize kv cache
        if "caches" not in self.model_inputs:
            self.initialize_kv_cache()

        # TODO: Init role in initialize process
        if req_dicts[-1].disaggregate_info is not None:
            if req_dicts[-1].disaggregate_info["role"] == "prefill":
                self.role = "prefill"
                os.environ["PREFILL_NODE_ONE_STEP_STOP"] = "1"
            elif req_dicts[-1].disaggregate_info["role"] == "decode":
                self.role = "decode"
        else:
            self.role = "mixed"

        req_len = len(req_dicts)
        for i in range(req_len):
            request = req_dicts[i]
            idx = request.idx
            length = len(request.prompt_token_ids)
            self.input_ids_len[idx] = length - 1

            if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                length = len(request.prompt_token_ids)
                if length > 1:
                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length]
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                        request.prompt_token_ids
                    )[1:]
                self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                prefill_token_num = self.max_draft_token_num + 1
                self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
                    request.draft_token_ids[1:2], dtype="int64"
                )

                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
                self.seq_lens_this_time_buffer[idx : idx + 1] = prefill_token_num

                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False
                self.model_inputs["step_idx"][idx : idx + 1] = 1

                encoder_block_num = len(request.block_tables)
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.block_tables, dtype="int32"
                )
            else:
                length = len(request.prompt_token_ids)
                if length > 1:
                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
                        "input_ids"
                    ][idx : idx + 1, 1:length]
                    self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
                        request.prompt_token_ids
                    )[1:]
                self.model_inputs["pre_ids"][idx : idx + 1] = -1
                self.model_inputs["step_idx"][idx : idx + 1] = 0
                if self.cache_config.enable_chunked_prefill:
                    token_chunk_size = request.prefill_chunk_info[0]
                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
                    self.seq_lens_this_time_buffer[idx : idx + 1] = token_chunk_size
                else:
                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
                    self.seq_lens_this_time_buffer[idx : idx + 1] = length

                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
                self.model_inputs["stop_flags"][idx : idx + 1] = False
                self.model_inputs["batch_drop"][idx : idx + 1] = False

                encoder_block_num = len(request.get("block_tables"))
                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                    request.get("block_tables"), dtype="int32"
                )
        self.model_inputs["not_need_stop"][0] = True
        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

    def _initialize_forward_meta(self):
        """
        Initialize forward meta and attention meta data
        """
        # Initialize forward meta
        self.forward_meta = ForwardMeta(
            input_ids=self.model_inputs["input_ids"],
            ids_remove_padding=self.model_inputs["ids_remove_padding"],
            rotary_embs=self.model_inputs["rope_emb"],
            attn_backend=self.attn_backends[0],
            decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
            decoder_num_blocks_device=self.model_inputs["decoder_num_blocks_device"],
            decoder_chunk_size_device=self.model_inputs["decoder_chunk_size_device"],
            max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
            batch_id_per_token=self.model_inputs["batch_id_per_token"],
            cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
            cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
            block_tables=self.model_inputs["block_tables"],
            caches=self.model_inputs["caches"],
            encoder_batch_ids=self.model_inputs["encoder_batch_ids"],
            encoder_tile_ids_per_batch=self.model_inputs["encoder_tile_ids_per_batch"],
            encoder_num_blocks_x_cpu=self.model_inputs["encoder_num_blocks_x_cpu"],
            kv_batch_ids=self.model_inputs["kv_batch_ids"],
            kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"],
            kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"],
            max_len_kv_cpu=self.model_inputs["max_len_kv_cpu"],
        )

        # Initialize attention metadata
        for attn_backend in self.attn_backends:
            attn_backend.init_attention_metadata(self.forward_meta)

    def _prepare_inputs(self, full_hidden_states):
        """
        Prepare MTP inputs
        """
        use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
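        # Rebuild the draft model's per-request state from the target model's accepted tokens.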
        draft_model_preprocess(
            self.model_inputs["draft_tokens"],
            self.model_inputs["input_ids"],
            self.model_inputs["stop_flags"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["not_need_stop"],
            self.model_inputs["batch_drop"],
            self.model_inputs["is_block_step"],
            self.model_inputs["pre_ids"],
            self.target_model_inputs["accept_tokens"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["seq_lens_decoder"],
            self.target_model_inputs["step_idx"],
            self.target_model_inputs["stop_flags"],
            self.target_model_inputs["is_block_step"],
            self.target_model_inputs["draft_tokens"],
            self.num_model_steps,
            self.speculative_method in ["eagle", "mtp"],
            self.role == "prefill",
            use_v1_cache_scheduler,
        )

        target_hidden_states = eagle_get_hidden_states(
            full_hidden_states,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["stop_flags"],
            self.target_model_inputs["accept_num"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.num_model_steps,
        )
        if isinstance(target_hidden_states, list):
            target_hidden_states = target_hidden_states[0]

        return target_hidden_states

    def _post_process(self, sampled_token_ids):
        """
        PostProcess for generation
        """
        draft_model_update(
            sampled_token_ids,
            self.model_inputs["draft_tokens"],
            self.model_inputs["pre_ids"],
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["seq_lens_encoder"],
            self.model_inputs["seq_lens_decoder"],
            self.model_inputs["step_idx"],
            self.model_inputs["output_cum_offsets"],
            self.model_inputs["stop_flags"],
            self.model_inputs["not_need_stop"],
            self.model_inputs["max_dec_len"],
            self.model_inputs["eos_token_id"],
            self.model_inputs["base_model_draft_tokens"],
            self.max_model_len,
            self.model_inputs["substep"],
        )
        if self.role == "prefill":
            mtp_save_first_token(
                self.model_inputs["base_model_draft_tokens"],
                self.model_inputs["not_need_stop"],
                self.local_rank,
                self.parallel_config.use_ep,
            )

    def _propose(self, target_hidden_states):
        """
        Main process for MTP inference
        """
        for substep in range(self.num_model_steps):
            if self.model_inputs["not_need_stop"]:
                self.model_inputs["substep"] = substep
                # Remove padding
                (
                    ids_remove_padding,
                    batch_id_per_token,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    output_cum_offsets,
                    output_padding_offset,
                ) = pre_process(
                    self.model_inputs["input_ids"],
                    self.model_inputs["seq_lens_this_time"],
                    True,
                    self.model_inputs["draft_tokens"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["seq_lens_decoder"],
                )

                # Initialize forward meta data
                self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
                self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
                self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)

                # for speculative decoding
                self.model_inputs["output_cum_offsets"] = output_cum_offsets
                self.model_inputs["output_padding_offset"] = output_padding_offset
                self._initialize_forward_meta()

                # Get sampling metadata
                self.sampling_metadata = SamplingMetadata(
                    temperature=self.model_inputs["temperature"],
                    top_p=self.model_inputs["top_p"],
                    top_k=self.model_inputs["top_k"],
                    step_idx=self.model_inputs["step_idx"],
                    pre_token_ids=self.model_inputs["pre_ids"],
                    frequency_penalties=self.model_inputs["frequency_score"],
                    presence_penalties=self.model_inputs["presence_score"],
                    repetition_penalties=self.model_inputs["penalty_score"],
                    min_dec_lens=self.model_inputs["min_dec_len"],
                    bad_words_token_ids=self.model_inputs["bad_tokens"],
                    eos_token_ids=self.model_inputs["eos_token_id"],
                )

                if self.num_model_steps > 1:
                    self.last_seq_lens_this_time = paddle.clone(self.model_inputs["seq_lens_this_time"])

                model_output = self.model(
                    ids_remove_padding=self.model_inputs["ids_remove_padding"],
                    previous_hidden_states=target_hidden_states,
                    forward_meta=self.forward_meta,
                )

                hidden_states = rebuild_padding(
                    model_output,
                    self.model_inputs["cu_seqlens_q"],
                    self.model_inputs["seq_lens_this_time"],
                    self.model_inputs["seq_lens_decoder"],
                    self.model_inputs["seq_lens_encoder"],
                    self.model_inputs["output_padding_offset"],
                    self.parallel_config.max_model_len,
                )

                # Compute logits, then sample
                logits = self.model.compute_logits(hidden_states)

                sampled_token_ids = self.sampler(
                    logits,
                    self.sampling_metadata,
                    self.max_model_len,
                    self.model_inputs,
                )
                if self.parallel_config.tensor_parallel_size > 1:
                    paddle.distributed.broadcast(sampled_token_ids, 0)

                self._post_process(sampled_token_ids)

                if substep != self.num_model_steps - 1:
                    target_hidden_states = self._get_self_hidden_states(hidden_states)

    def _get_self_hidden_states(self, hidden_states):
        """Extract the hidden states from the current substep to feed the next draft substep."""
        target_hidden_states = eagle_get_self_hidden_states(
            hidden_states,
            self.last_seq_lens_this_time,
            self.model_inputs["seq_lens_this_time"],
            self.model_inputs["step_idx"],
        )
        if isinstance(target_hidden_states, list):
            target_hidden_states = target_hidden_states[0]

        return target_hidden_states

    def update_task_chunk_prefill(self, task):
        """
        Update single task's chunk_prefill info
        """
        idx = task.idx
        start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])

        if task.chunk_idx == len(task.prefill_chunk_info):
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
            self.model_inputs["step_idx"][idx : idx + 1] = 1
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
        else:
            token_chunk_size = task.prefill_chunk_info[task.chunk_idx]
            if task.chunk_idx < len(task.prefill_chunk_info) - 1:
                self.model_inputs["input_ids"][idx, :token_chunk_size] = np.array(
                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1]
                )
            # Last prefill
            else:
                self.model_inputs["input_ids"][idx, : token_chunk_size - 1] = np.array(
                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size]
                )
            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
            self.model_inputs["step_idx"][idx : idx + 1] = 0
            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)

    def _update_status(self):
        """
        Update the main model's forward info for the next step.
        Allocate/free blocks for MTP.
        """
        draft_model_postprocess(
            self.target_model_inputs["draft_tokens"],
            self.target_model_inputs["seq_lens_this_time"],
            self.target_model_inputs["seq_lens_encoder"],
            self.target_model_inputs["stop_flags"],
        )

        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
            mtp_step_paddle(
                self.target_model_inputs["stop_flags"],
                self.model_inputs["stop_flags"],
                self.model_inputs["batch_drop"],
                self.model_inputs["seq_lens_this_time"],
                self.model_inputs["seq_lens_encoder"],
                self.model_inputs["seq_lens_decoder"],
                self.model_inputs["block_tables"],
                self.model_inputs["encoder_block_lens"],
                self.model_inputs["used_list_len"],
                self.model_inputs["free_list"],
                self.model_inputs["free_list_len"],
                self.cache_config.block_size,
                self.max_draft_token_num,
            )

    def _extend_draft_token_with_ngram_match(self):
        # TODO(liuzichang): Optimize this kernel into a CUDA kernel to reduce latency
        device = paddle.CUDAPinnedPlace()

        draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
        seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
        seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
        hybrid_mtp_ngram(
            self.model_inputs["input_ids_cpu"],
            self.input_ids_len,
            self.model_inputs["pre_ids"]._copy_to(device, True),
            self.model_inputs["step_idx"].cpu(),
            self.target_model_inputs["actual_draft_token_num"].cpu(),
            draft_tokens,
            seq_lens_this_time,
            seq_lens_decoder,
            self.model_inputs["max_dec_len"].cpu(),
            self.max_ngram_size,
            self.min_ngram_size,
            self.max_draft_token_num,
        )
        self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
        self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

    def _run_impl(self, full_hidden_states):
        """Run one MTP step: prepare inputs, propose draft tokens, and update the main model's status."""
        target_hidden_states = self._prepare_inputs(full_hidden_states)
        self._propose(target_hidden_states=target_hidden_states)
        self._update_status()
        if self.hybrid_mode:
            self._extend_draft_token_with_ngram_match()

    def is_chunk_prefill_enabled(self):
        """Whether chunked prefill is supported by the MTP proposer."""
        return True