Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
polish code with new pre-commit rule (#2923)
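The hunks below are formatting-only changes introduced by the new hook: backslash-continued imports become parenthesized imports, hanging-indent continuations become flat bodies with trailing commas, "...".format() calls become f-strings, single quotes become double quotes, and slice bounds gain spaces (idx:idx + 1 becomes idx : idx + 1). The added lines look like the output of a formatter such as black, though the exact hook is not shown in this diff. A rough before/after sketch of that style, using an invented helper rather than FastDeploy code:

# Hypothetical before/after illustration of the style the new pre-commit rule enforces;
# neither function exists in FastDeploy, they only mirror patterns visible in the hunks below.

def cache_name_before(layer_id, rank):
    return 'key_caches_{}_rank{}'.format(layer_id,
                                         rank)

def cache_name_after(layer_id, rank):
    return f"key_caches_{layer_id}_rank{rank}"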
@@ -23,20 +23,22 @@ import paddle
 from fastdeploy.engine.request import Request
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend
-from fastdeploy.model_executor.layers.attention.base_attention_backend import \
-    AttentionBackend
+from fastdeploy.model_executor.layers.attention.base_attention_backend import (
+    AttentionBackend,
+)
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
-from fastdeploy.model_executor.ops.gpu import (draft_model_postprocess,
-                                               draft_model_preprocess,
-                                               draft_model_update,
-                                               eagle_get_hidden_states,
-                                               mtp_save_first_token,
-                                               mtp_step_paddle,
-                                               share_external_data)
-from fastdeploy.model_executor.pre_and_post_process import (pre_process,
-                                                            rebuild_padding)
+from fastdeploy.model_executor.ops.gpu import (
+    draft_model_postprocess,
+    draft_model_preprocess,
+    draft_model_update,
+    eagle_get_hidden_states,
+    mtp_save_first_token,
+    mtp_step_paddle,
+    share_external_data,
+)
+from fastdeploy.model_executor.pre_and_post_process import pre_process, rebuild_padding

 from .base import Proposer

@@ -46,8 +48,7 @@ class MTPProposer(Proposer):
     Proposer for Multi-Token-Prediction(MTP)
     """

-    def __init__(self, cfg, main_model, local_rank, device_id,
-                 main_model_inputs):
+    def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs):
         super().__init__(cfg)
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
@@ -71,12 +72,10 @@ class MTPProposer(Proposer):
         self.model_config.architectures[0] = "Ernie4_5_MTPForCausalLM"
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
-        self.parallel_config.model_name_or_path = (
-            self.speculative_config.model_name_or_path)
+        self.parallel_config.model_name_or_path = self.speculative_config.model_name_or_path
         self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
         if self.speculative_config.quantization != "":
-            self.model_config.quantization = (
-                self.speculative_config.quantization)
+            self.model_config.quantization = self.speculative_config.quantization
         self.model_config.start_layer_index = self.num_main_model_layers
         self.speculative_config.model_type = "mtp"

@@ -84,43 +83,39 @@ class MTPProposer(Proposer):
         """
         Load MTP Layer
         """
-        from fastdeploy.model_executor.model_loader import \
-            get_model_from_loader
+        from fastdeploy.model_executor.model_loader import get_model_from_loader

         self.model = get_model_from_loader(self.cfg)

-    def dummy_prefill_inputs(self, num_tokens: int, batch_size: int,
-                             expected_decode_len: int):
+    def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to model_inputs"""
         max_dec_len = expected_decode_len + 1
         self.num_gpu_blocks = self.parallel_config.total_block_num
         self.initialize_kv_cache()
-        full_length = min(num_tokens // batch_size,
-                          self.parallel_config.max_model_len - max_dec_len)
+        full_length = min(
+            num_tokens // batch_size,
+            self.parallel_config.max_model_len - max_dec_len,
+        )
         input_length = int(full_length * self.parallel_config.kv_cache_ratio)
-        block_num = ((input_length + self.parallel_config.block_size - 1) //
-                     self.parallel_config.block_size +
-                     self.parallel_config.enc_dec_block_num)
+        block_num = (
+            input_length + self.parallel_config.block_size - 1
+        ) // self.parallel_config.block_size + self.parallel_config.enc_dec_block_num

         for i in range(batch_size):
             idx = i
-            self.model_inputs["input_ids"][idx:idx +
-                                           1, :input_length] = (np.array(
-                                               [5] * input_length))
-            self.model_inputs["eos_token_id"][:] = np.array(
-                [2], dtype="int64").reshape(-1, 1)
-            self.model_inputs["seq_lens_this_time"][idx:idx + 1] = input_length
-            self.model_inputs["seq_lens_encoder"][idx:idx + 1] = input_length
-            self.model_inputs["seq_lens_decoder"][idx:idx + 1] = 0
-            self.model_inputs["step_idx"][idx:idx + 1] = 0
-            self.model_inputs["max_dec_len"][idx:idx + 1] = max_dec_len
-            self.model_inputs["stop_flags"][idx:idx + 1] = False
+            self.model_inputs["input_ids"][idx : idx + 1, :input_length] = np.array([5] * input_length)
+            self.model_inputs["eos_token_id"][:] = np.array([2], dtype="int64").reshape(-1, 1)
+            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = input_length
+            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = input_length
+            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+            self.model_inputs["step_idx"][idx : idx + 1] = 0
+            self.model_inputs["max_dec_len"][idx : idx + 1] = max_dec_len
+            self.model_inputs["stop_flags"][idx : idx + 1] = False

-            self.model_inputs["encoder_block_lens"][idx:idx + 1] = block_num
-            self.model_inputs["block_tables"][idx:idx +
-                                              1, :block_num] = (np.arange(
-                                                  idx * block_num,
-                                                  (idx + 1) * block_num, 1))
+            self.model_inputs["encoder_block_lens"][idx : idx + 1] = block_num
+            self.model_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(
+                idx * block_num, (idx + 1) * block_num, 1
+            )

     def initialize_kv_cache(self):
         """
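The reformatted block_num expression in dummy_prefill_inputs is still a ceiling division plus the encoder/decoder reserve blocks. A minimal standalone sketch, using assumed example values rather than anything read from the FastDeploy config:

# Both layouts of the expression compute ceil(input_length / block_size) + enc_dec_block_num.
input_length = 1000      # assumed example value
block_size = 64          # assumed example value
enc_dec_block_num = 2    # assumed example value

block_num_old_layout = ((input_length + block_size - 1) //
                        block_size +
                        enc_dec_block_num)
block_num_new_layout = (
    input_length + block_size - 1
) // block_size + enc_dec_block_num

assert block_num_old_layout == block_num_new_layout == 18  # ceil(1000 / 64) = 16, plus 2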
@@ -131,41 +126,41 @@ class MTPProposer(Proposer):

         cache_type = self.parallel_config.dtype

-        if (self.quant_config
-                and hasattr(self.quant_config, "kv_cache_quant_type")
-                and self.quant_config.kv_cache_quant_type is not None):
-            cache_type = 'uint8'
+        if (
+            self.quant_config
+            and hasattr(self.quant_config, "kv_cache_quant_type")
+            and self.quant_config.kv_cache_quant_type is not None
+        ):
+            cache_type = "uint8"

         # Get kv cache shape
-        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
-            max_num_blocks=self.num_gpu_blocks)
-        if (not self.parallel_config.do_profile
-                and (self.parallel_config.enable_prefix_caching
-                     or self.parallel_config.splitwise_role != "mixed")):
+        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(max_num_blocks=self.num_gpu_blocks)
+        if not self.parallel_config.do_profile and (
+            self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
+        ):
             cache_kvs_list = []
             for i in range(
-                    self.num_main_model_layers,
-                    self.num_main_model_layers + self.model_config.num_hidden_layers):
+                self.num_main_model_layers,
+                self.num_main_model_layers + self.model_config.num_hidden_layers,
+            ):
                 key_cache = paddle.empty(shape=[], dtype=cache_type)
                 key_cache_name = f"key_caches_{i}_rank{self.local_rank}.device{self.device_id}"
                 val_cache_name = f"value_caches_{i}_rank{self.local_rank}.device{self.device_id}"
-                key_cache = share_external_data(key_cache, key_cache_name,
-                                                kv_cache_shape)
+                key_cache = share_external_data(key_cache, key_cache_name, kv_cache_shape)
                 cache_kvs_list.append(key_cache)
                 value_cache = paddle.empty(shape=[], dtype=cache_type)
-                value_cache = share_external_data(value_cache, val_cache_name,
-                                                  kv_cache_shape)
+                value_cache = share_external_data(value_cache, val_cache_name, kv_cache_shape)
                 cache_kvs_list.append(value_cache)

             self.model_inputs["caches"] = cache_kvs_list
         else:
             for i in range(self.model_config.num_hidden_layers):
-                self.cache_kvs["key_caches_{}".format(i)] = paddle.full(
+                self.cache_kvs[f"key_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,
                     dtype=cache_type,
                 )
-                self.cache_kvs["value_caches_{}".format(i)] = paddle.full(
+                self.cache_kvs[f"value_caches_{i}"] = paddle.full(
                     shape=kv_cache_shape,
                     fill_value=0,
                     dtype=cache_type,
@@ -175,18 +170,19 @@ class MTPProposer(Proposer):
             del value
         paddle.device.cuda.empty_cache()

-    def _initialize_attn_backend(self, ) -> None:
+    def _initialize_attn_backend(
+        self,
+    ) -> None:
         """
         Initialize attention backends and forward metadata
         """
         assert len(self.attn_backends) == 0

         # TODO(gongshaotian): Get rank from config
-        num_heads = (self.model_config.num_attention_heads //
-                     self.parallel_config.tensor_parallel_size)
+        num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
         self.model_config.kv_num_heads = (
-            int(self.model_config.num_key_value_heads) //
-            self.parallel_config.tensor_parallel_size)
+            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size
+        )
         head_dim = self.model_config.head_dim

         # Get the attention backend
@@ -217,28 +213,25 @@ class MTPProposer(Proposer):
         """

         self.main_model_num_gpu_blocks = num_gpu_blocks
-        self.num_gpu_blocks = int(
-            num_gpu_blocks *
-            self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.parallel_config.enable_prefix_caching
-                or self.parallel_config.splitwise_role != "mixed"):
+        self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
+        if not (self.parallel_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
             self.initialize_kv_cache()

         # Reset free list
         free_list = list(
             range(
                 self.num_gpu_blocks - 1,
-                int(self.main_model_num_gpu_blocks *
-                    self.parallel_config.kv_cache_ratio) - 1,
+                int(self.main_model_num_gpu_blocks * self.parallel_config.kv_cache_ratio) - 1,
                 -1,
-            ))
+            )
+        )
         self.free_list_len = len(free_list)
-        self.model_inputs.update({
-            "free_list":
-            paddle.to_tensor(free_list, dtype="int32"),
-            "free_list_len":
-            paddle.full([1], self.free_list_len, dtype="int32"),
-        })
+        self.model_inputs.update(
+            {
+                "free_list": paddle.to_tensor(free_list, dtype="int32"),
+                "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
+            }
+        )
         self.parallel_config.do_profile = False

     def _init_model_inputs(self):
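The free list rebuilt above hands the proposer the tail of the expanded block pool: indices run downward from num_gpu_blocks - 1 to int(main_model_num_gpu_blocks * kv_cache_ratio), inclusive. A small sketch with made-up numbers (none of these values come from a real config):

# Made-up example values; the real ones come from parallel_config / speculative_config.
main_model_num_gpu_blocks = 100
num_gpu_block_expand_ratio = 1.2
kv_cache_ratio = 0.7

num_gpu_blocks = int(main_model_num_gpu_blocks * num_gpu_block_expand_ratio)  # 120
free_list = list(
    range(
        num_gpu_blocks - 1,
        int(main_model_num_gpu_blocks * kv_cache_ratio) - 1,
        -1,
    )
)
# Blocks 119 down to 70 (inclusive) are available to the MTP proposer.
assert free_list[0] == 119 and free_list[-1] == 70 and len(free_list) == 50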
@@ -247,44 +240,27 @@ class MTPProposer(Proposer):
         """
         self.model_inputs = {}
         # Same shape/dytpe with base model
-        self.model_inputs["block_tables"] = paddle.clone(
-            self.main_model_inputs["block_tables"])
-        self.model_inputs["input_ids"] = paddle.clone(
-            self.main_model_inputs["input_ids"])
-        self.model_inputs["seq_lens_this_time"] = paddle.clone(
-            self.main_model_inputs["seq_lens_this_time"])
-        self.model_inputs["seq_lens_encoder"] = paddle.clone(
-            self.main_model_inputs["seq_lens_encoder"])
-        self.model_inputs["seq_lens_decoder"] = paddle.clone(
-            self.main_model_inputs["seq_lens_decoder"])
-        self.model_inputs["step_idx"] = paddle.clone(
-            self.main_model_inputs["step_idx"])
-        self.model_inputs["stop_flags"] = paddle.clone(
-            self.main_model_inputs["stop_flags"])
-        self.model_inputs["stop_nums"] = paddle.clone(
-            self.main_model_inputs["stop_nums"])
-        self.model_inputs["not_need_stop"] = paddle.to_tensor([False],
-                                                              dtype="bool",
-                                                              place="cpu")
-        self.model_inputs["pre_ids"] = paddle.clone(
-            self.main_model_inputs["pre_ids"])
-        self.model_inputs["ids_remove_padding"] = paddle.clone(
-            self.main_model_inputs["ids_remove_padding"])
-        self.model_inputs["cum_offsets"] = paddle.clone(
-            self.main_model_inputs["cum_offsets"])
-        self.model_inputs["batch_id_per_token"] = paddle.clone(
-            self.main_model_inputs["batch_id_per_token"])
-        self.model_inputs["cu_seqlens_q"] = paddle.clone(
-            self.main_model_inputs["cu_seqlens_q"])
-        self.model_inputs["cu_seqlens_k"] = paddle.clone(
-            self.main_model_inputs["cu_seqlens_k"])
-        self.model_inputs["decoder_batch_ids"] = paddle.clone(
-            self.main_model_inputs["decoder_batch_ids"])
+        self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
+        self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
+        self.model_inputs["seq_lens_this_time"] = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+        self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
+        self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
+        self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
+        self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"])
+        self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"])
+        self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
+        self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"])
+        self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"])
+        self.model_inputs["cum_offsets"] = paddle.clone(self.main_model_inputs["cum_offsets"])
+        self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"])
+        self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"])
+        self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"])
+        self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"])
         self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
-            self.main_model_inputs["decoder_tile_ids_per_batch"])
+            self.main_model_inputs["decoder_tile_ids_per_batch"]
+        )

-        tmp_position_ids = paddle.arange(
-            self.parallel_config.max_model_len).reshape((1, -1))
+        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
         self.model_inputs["rope_emb"] = get_rope(
             rotary_dim=self.model_config.head_dim,
             position_ids=tmp_position_ids,
@@ -294,55 +270,41 @@ class MTPProposer(Proposer):
         # self.model_inputs["caches"] = self.cache_kvs
         # Inherit generation hyperparameters from the main model for consistency
         self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
-        self.model_inputs["temperature"] = self.main_model_inputs[
-            "temperature"]
-        self.model_inputs["eos_token_id"] = self.main_model_inputs[
-            "eos_token_id"]
-        self.model_inputs["penalty_score"] = self.main_model_inputs[
-            "penalty_score"]
-        self.model_inputs["frequency_score"] = self.main_model_inputs[
-            "frequency_score"]
-        self.model_inputs["presence_score"] = self.main_model_inputs[
-            "presence_score"]
+        self.model_inputs["temperature"] = self.main_model_inputs["temperature"]
+        self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"]
+        self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"]
+        self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"]
+        self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"]
         self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"]

-        self.model_inputs["max_dec_len"] = self.main_model_inputs[
-            "max_dec_len"]
-        self.model_inputs["min_dec_len"] = self.main_model_inputs[
-            "min_dec_len"]
+        self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"]
+        self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"]

         self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"]

         # Integrate the updated results in model forward
-        self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs[
-            "draft_tokens"]
+        self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
         self.model_inputs["substep"] = 0

         # Input tokens
-        self.model_inputs["draft_tokens"] = paddle.full(
-            shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
+        self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")

-        self.model_inputs["encoder_block_lens"] = paddle.clone(
-            self.main_model_inputs["encoder_block_lens"])
+        self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])

         self.free_list = list(
             range(
                 self.parallel_config.total_block_num - 1,
-                int(self.parallel_config.total_block_num *
-                    self.parallel_config.kv_cache_ratio) - 1,
+                int(self.parallel_config.total_block_num * self.parallel_config.kv_cache_ratio) - 1,
                 -1,
-            ))
+            )
+        )
         self.free_list_len = len(self.free_list)

-        self.model_inputs["free_list"] = paddle.to_tensor(self.free_list,
-                                                          dtype="int32")
-        self.model_inputs["free_list_len"] = paddle.full(
-            shape=[1], fill_value=self.free_list_len, dtype="int32")
+        self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
+        self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")

-        self.model_inputs["batch_drop"] = paddle.full(
-            shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
-        self.model_inputs["used_list_len"] = paddle.full(
-            shape=[self.max_num_seqs], fill_value=0, dtype="int32")
+        self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
+        self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")

     def insert_prefill_inputs(self, req_dicts: List[Request]):
         """
@@ -368,67 +330,56 @@ class MTPProposer(Proposer):
             idx = request.idx
             length = len(request.prompt_token_ids)

-            if (req_dicts[i].disaggregate_info is not None
-                    and req_dicts[i].disaggregate_info["role"] == "decode"):
+            if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
                 length = len(request.prompt_token_ids)
-                self.model_inputs["pre_ids"][idx:idx + 1] = (
-                    request.prompt_token_ids[-1])
+                self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
                 prefill_token_num = self.max_draft_token_num + 1
-                self.model_inputs["draft_tokens"][idx : idx + 1, \
-                    0:1] = paddle.to_tensor(request.draft_token_ids[0:1], dtype='int64')
+                self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
+                    request.draft_token_ids[0:1], dtype="int64"
+                )

-                self.model_inputs["seq_lens_encoder"][idx:idx + 1] = 0
-                self.model_inputs["seq_lens_decoder"][idx:idx + 1] = length
-                self.model_inputs['seq_lens_this_time'][idx:idx +
-                                                        1] = prefill_token_num
+                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = length
+                self.model_inputs["seq_lens_this_time"][idx : idx + 1] = prefill_token_num

-                self.model_inputs["stop_flags"][idx:idx + 1] = False
-                self.model_inputs["batch_drop"][idx:idx + 1] = False
-                self.model_inputs["step_idx"][idx:idx + 1] = 1
+                self.model_inputs["stop_flags"][idx : idx + 1] = False
+                self.model_inputs["batch_drop"][idx : idx + 1] = False
+                self.model_inputs["step_idx"][idx : idx + 1] = 1
                 encoder_block_num = len(request.block_tables)

-                self.model_inputs["encoder_block_lens"][idx:idx +
-                                                        1] = encoder_block_num
-                self.model_inputs["block_tables"][idx:idx + 1, :] = -1
-                self.model_inputs["block_tables"][
-                    idx:idx + 1, :encoder_block_num] = np.array(
-                        request.block_tables, dtype="int32")
+                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )

             else:
                 length = len(request.prompt_token_ids)

                 if length > 1:
-                    self.model_inputs["input_ids"][
-                        idx:idx + 1, :length -
-                        1] = self.main_model_inputs["input_ids"][idx:idx + 1,
-                                                                 1:length]
-                self.model_inputs["pre_ids"][idx:idx + 1] = -1
-                self.model_inputs["step_idx"][idx:idx + 1] = 0
+                    self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][
+                        idx : idx + 1, 1:length
+                    ]
+                self.model_inputs["pre_ids"][idx : idx + 1] = -1
+                self.model_inputs["step_idx"][idx : idx + 1] = 0
                 if self.parallel_config.enable_chunked_prefill:
                     token_chunk_size = request.prefill_chunk_info[0]
-                    self.model_inputs["seq_lens_encoder"][idx:idx +
-                                                          1] = token_chunk_size
-                    self.model_inputs["seq_lens_this_time"][
-                        idx:idx + 1] = token_chunk_size
+                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
+                    self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
                 else:
-                    self.model_inputs["seq_lens_encoder"][idx:idx + 1] = length
-                    self.model_inputs["seq_lens_this_time"][idx:idx +
-                                                            1] = length
+                    self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
+                    self.model_inputs["seq_lens_this_time"][idx : idx + 1] = length

-                self.model_inputs["seq_lens_decoder"][idx:idx +
-                                                      1] = (request.get(
-                                                          "seq_lens_decoder",
-                                                          0))
-                self.model_inputs["stop_flags"][idx:idx + 1] = False
-                self.model_inputs["batch_drop"][idx:idx + 1] = False
+                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = request.get("seq_lens_decoder", 0)
+                self.model_inputs["stop_flags"][idx : idx + 1] = False
+                self.model_inputs["batch_drop"][idx : idx + 1] = False

                 encoder_block_num = len(request.get("block_tables"))
-                self.model_inputs["encoder_block_lens"][idx:idx +
-                                                        1] = encoder_block_num
-                self.model_inputs["block_tables"][idx:idx + 1, :] = -1
-                self.model_inputs["block_tables"][
-                    idx:idx + 1, :encoder_block_num] = np.array(
-                        request.get("block_tables"), dtype="int32")
+                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.get("block_tables"), dtype="int32"
+                )
         self.model_inputs["not_need_stop"][0] = True

     def _initialize_forward_meta(self):
@@ -451,10 +402,9 @@ class MTPProposer(Proposer):
             cu_seqlens_q=self.model_inputs["cu_seqlens_q"],
             cu_seqlens_k=self.model_inputs["cu_seqlens_k"],
             block_tables=self.model_inputs["block_tables"],
-            caches=self.model_inputs["caches"]
+            caches=self.model_inputs["caches"],
         )

-
         # Initialzie attention meta data
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
@@ -557,17 +507,14 @@ class MTPProposer(Proposer):
             self.model_inputs["seq_lens_decoder"],
         )
         # Initialize forward meta data
-        self.model_inputs["ids_remove_padding"].copy_(
-            ids_remove_padding, False)
+        self.model_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
         self.model_inputs["cum_offsets"].copy_(cum_offsets, False)
-        self.model_inputs["batch_id_per_token"].copy_(
-            batch_id_per_token, False)
+        self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
         self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
         self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
         # for speculative decoding
         self.model_inputs["output_cum_offsets"] = output_cum_offsets
-        self.model_inputs["output_padding_offset"] = (
-            output_padding_offset)
+        self.model_inputs["output_padding_offset"] = output_padding_offset
         self._initialize_forward_meta()

         # Get sampling metadata
@@ -620,37 +567,29 @@ class MTPProposer(Proposer):
         Update single task's chunk_prefill info
         """
         idx = task.idx
-        start_idx = sum(task.prefill_chunk_info[:task.chunk_idx])
+        start_idx = sum(task.prefill_chunk_info[: task.chunk_idx])

         if task.chunk_idx == len(task.prefill_chunk_info):
-            self.model_inputs['seq_lens_encoder'][idx:idx + 1] = 0
-            self.model_inputs["step_idx"][idx:idx + 1] = 1
-            self.model_inputs["seq_lens_decoder"][idx:idx +
-                                                  1] = start_idx + task.get(
-                                                      "seq_lens_decoder", 0)
+            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+            self.model_inputs["step_idx"][idx : idx + 1] = 1
+            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)
         else:
             token_chunk_size = task.prefill_chunk_info[task.chunk_idx]

             if task.chunk_idx < len(task.prefill_chunk_info) - 1:
-                self.model_inputs['input_ids'][
-                    idx, :token_chunk_size] = np.array(
-                        task.prompt_token_ids[start_idx + 1:start_idx +
-                                              token_chunk_size + 1])
+                self.model_inputs["input_ids"][idx, :token_chunk_size] = np.array(
+                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1]
+                )
             # Last prefill
             else:
-                self.model_inputs['input_ids'][
-                    idx, :token_chunk_size - 1] = np.array(
-                        task.prompt_token_ids[start_idx + 1:start_idx +
-                                              token_chunk_size])
+                self.model_inputs["input_ids"][idx, : token_chunk_size - 1] = np.array(
+                    task.prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size]
+                )

-            self.model_inputs["seq_lens_this_time"][idx:idx +
-                                                    1] = token_chunk_size
-            self.model_inputs['seq_lens_encoder'][idx:idx +
-                                                  1] = token_chunk_size
-            self.model_inputs["step_idx"][idx:idx + 1] = 0
-            self.model_inputs["seq_lens_decoder"][idx:idx +
-                                                  1] = start_idx + task.get(
-                                                      "seq_lens_decoder", 0)
+            self.model_inputs["seq_lens_this_time"][idx : idx + 1] = token_chunk_size
+            self.model_inputs["seq_lens_encoder"][idx : idx + 1] = token_chunk_size
+            self.model_inputs["step_idx"][idx : idx + 1] = 0
+            self.model_inputs["seq_lens_decoder"][idx : idx + 1] = start_idx + task.get("seq_lens_decoder", 0)

     def _update_status(self):
         """
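The chunk bookkeeping above copies each prompt chunk shifted right by one token, the same one-token shift applied to input_ids in insert_prefill_inputs, and the last chunk copies one token fewer. A toy sketch of that slicing with invented prompt and chunk sizes:

# Toy illustration of the chunked-prefill slicing; the prompt and chunk sizes are invented.
prompt_token_ids = list(range(10, 20))  # 10 tokens: 10, 11, ..., 19
prefill_chunk_info = [4, 4, 2]          # assumed chunk sizes

for chunk_idx, token_chunk_size in enumerate(prefill_chunk_info):
    start_idx = sum(prefill_chunk_info[:chunk_idx])
    if chunk_idx < len(prefill_chunk_info) - 1:
        chunk = prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size + 1]
    else:
        # Last prefill chunk: one token fewer, matching the ": token_chunk_size - 1" slice above.
        chunk = prompt_token_ids[start_idx + 1 : start_idx + token_chunk_size]
    print(chunk_idx, chunk)
# 0 [11, 12, 13, 14]
# 1 [15, 16, 17, 18]
# 2 [19]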