Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-04 16:22:57 +08:00)
support mtp in v1_scheduler mode (#3695)
@@ -19,8 +19,10 @@ from typing import List
 import numpy as np
 import paddle
 from paddleformers.utils.log import logger

-from fastdeploy.engine.request import Request
+from fastdeploy import envs
+from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
@@ -50,14 +52,14 @@ class MTPProposer(Proposer):
     Proposer for Multi-Token-Prediction(MTP)
     """

-    def __init__(self, cfg, main_model, local_rank, device_id, main_model_inputs):
+    def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
         super().__init__(cfg)
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
         self.device_id = device_id
         self._update_cfg(main_model)
         self._load_model()
-        self.main_model_inputs = main_model_inputs
+        self.target_model_inputs = target_model_inputs
         self.mtp_strategy = self.speculative_config.mtp_strategy
         self.hybrid_mode = self.mtp_strategy == "with_ngram" and self.max_draft_token_num > self.num_model_steps

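Editor's note: only the last constructor keyword changed here (main_model_inputs -> target_model_inputs). A hypothetical call site after the rename might look like the sketch below; cfg, main_model, and share_inputs are placeholders, not names taken from this diff:

    # Hypothetical sketch: constructing the proposer with the renamed keyword.
    # cfg, main_model and share_inputs stand in for the caller's real objects.
    proposer = MTPProposer(
        cfg,
        main_model,
        local_rank=0,
        device_id=0,
        target_model_inputs=share_inputs,  # formerly main_model_inputs
    )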
@@ -199,14 +201,16 @@ class MTPProposer(Proposer):
         encoder_block_shape_q = 64
         decoder_block_shape_q = 16

-        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
+        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"])
         self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
-            self.main_model_inputs["decoder_tile_ids_per_batch"]
+            self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
         self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
-            self.main_model_inputs["decoder_num_blocks_cpu"]
+            self.target_model_inputs["decoder_num_blocks_cpu"]
         ).pin_memory()
-        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
+        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(
+            self.target_model_inputs["max_len_tensor_cpu"]
+        ).cpu()

         # Get the attention backend
         attn_cls = get_attention_backend()
@@ -265,24 +269,24 @@ class MTPProposer(Proposer):
         """
         self.model_inputs = {}
         # Same shape/dytpe with base model
-        self.model_inputs["block_tables"] = paddle.clone(self.main_model_inputs["block_tables"])
-        self.model_inputs["input_ids"] = paddle.clone(self.main_model_inputs["input_ids"])
-        self.seq_lens_this_time_buffer = paddle.clone(self.main_model_inputs["seq_lens_this_time"])
+        self.model_inputs["block_tables"] = paddle.clone(self.target_model_inputs["block_tables"])
+        self.model_inputs["input_ids"] = paddle.clone(self.target_model_inputs["input_ids"])
+        self.seq_lens_this_time_buffer = paddle.clone(self.target_model_inputs["seq_lens_this_time"])

-        self.model_inputs["seq_lens_encoder"] = paddle.clone(self.main_model_inputs["seq_lens_encoder"])
-        self.model_inputs["seq_lens_decoder"] = paddle.clone(self.main_model_inputs["seq_lens_decoder"])
-        self.model_inputs["step_idx"] = paddle.clone(self.main_model_inputs["step_idx"])
-        self.model_inputs["stop_flags"] = paddle.clone(self.main_model_inputs["stop_flags"])
-        self.model_inputs["stop_nums"] = paddle.clone(self.main_model_inputs["stop_nums"])
+        self.model_inputs["seq_lens_encoder"] = paddle.clone(self.target_model_inputs["seq_lens_encoder"])
+        self.model_inputs["seq_lens_decoder"] = paddle.clone(self.target_model_inputs["seq_lens_decoder"])
+        self.model_inputs["step_idx"] = paddle.clone(self.target_model_inputs["step_idx"])
+        self.model_inputs["stop_flags"] = paddle.clone(self.target_model_inputs["stop_flags"])
+        self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
         self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
-        self.model_inputs["pre_ids"] = paddle.clone(self.main_model_inputs["pre_ids"])
-        self.model_inputs["ids_remove_padding"] = paddle.clone(self.main_model_inputs["ids_remove_padding"])
-        self.model_inputs["batch_id_per_token"] = paddle.clone(self.main_model_inputs["batch_id_per_token"])
-        self.model_inputs["cu_seqlens_q"] = paddle.clone(self.main_model_inputs["cu_seqlens_q"])
-        self.model_inputs["cu_seqlens_k"] = paddle.clone(self.main_model_inputs["cu_seqlens_k"])
-        self.model_inputs["decoder_batch_ids"] = paddle.clone(self.main_model_inputs["decoder_batch_ids"])
+        self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
+        self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
+        self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
+        self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
+        self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
+        self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
         self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
-            self.main_model_inputs["decoder_tile_ids_per_batch"]
+            self.target_model_inputs["decoder_tile_ids_per_batch"]
         )

         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -294,22 +298,22 @@ class MTPProposer(Proposer):
         )
         # self.model_inputs["caches"] = self.cache_kvs
         # Inherit generation hyperparameters from the main model for consistency
-        self.model_inputs["top_p"] = self.main_model_inputs["top_p"]
-        self.model_inputs["top_k"] = self.main_model_inputs["top_k"]
-        self.model_inputs["temperature"] = self.main_model_inputs["temperature"]
-        self.model_inputs["eos_token_id"] = self.main_model_inputs["eos_token_id"]
-        self.model_inputs["penalty_score"] = self.main_model_inputs["penalty_score"]
-        self.model_inputs["frequency_score"] = self.main_model_inputs["frequency_score"]
-        self.model_inputs["presence_score"] = self.main_model_inputs["presence_score"]
-        self.model_inputs["infer_seed"] = self.main_model_inputs["infer_seed"]
+        self.model_inputs["top_p"] = self.target_model_inputs["top_p"]
+        self.model_inputs["top_k"] = self.target_model_inputs["top_k"]
+        self.model_inputs["temperature"] = self.target_model_inputs["temperature"]
+        self.model_inputs["eos_token_id"] = self.target_model_inputs["eos_token_id"]
+        self.model_inputs["penalty_score"] = self.target_model_inputs["penalty_score"]
+        self.model_inputs["frequency_score"] = self.target_model_inputs["frequency_score"]
+        self.model_inputs["presence_score"] = self.target_model_inputs["presence_score"]
+        self.model_inputs["infer_seed"] = self.target_model_inputs["infer_seed"]

-        self.model_inputs["max_dec_len"] = self.main_model_inputs["max_dec_len"]
-        self.model_inputs["min_dec_len"] = self.main_model_inputs["min_dec_len"]
+        self.model_inputs["max_dec_len"] = self.target_model_inputs["max_dec_len"]
+        self.model_inputs["min_dec_len"] = self.target_model_inputs["min_dec_len"]

-        self.model_inputs["bad_tokens"] = self.main_model_inputs["bad_tokens"]
+        self.model_inputs["bad_tokens"] = self.target_model_inputs["bad_tokens"]

         # Integrate the updated results in model forward
-        self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
+        self.model_inputs["base_model_draft_tokens"] = self.target_model_inputs["draft_tokens"]
         self.model_inputs["substep"] = 0

         # Declare AttentionBackend buffers
@@ -323,7 +327,7 @@ class MTPProposer(Proposer):
             shape=[self.max_num_seqs, self.max_draft_token_num + 1], fill_value=-1, dtype="int64"
         )

-        self.model_inputs["encoder_block_lens"] = paddle.clone(self.main_model_inputs["encoder_block_lens"])
+        self.model_inputs["encoder_block_lens"] = paddle.clone(self.target_model_inputs["encoder_block_lens"])

         self.free_list = list(
             range(
@@ -337,14 +341,77 @@ class MTPProposer(Proposer):
         self.model_inputs["free_list"] = paddle.to_tensor(self.free_list, dtype="int32")
         self.model_inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.free_list_len, dtype="int32")

         self.model_inputs["is_block_step"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
         self.model_inputs["batch_drop"] = paddle.full(shape=[self.max_num_seqs, 1], fill_value=False, dtype="bool")
         self.model_inputs["used_list_len"] = paddle.full(shape=[self.max_num_seqs], fill_value=0, dtype="int32")
         if self.num_model_steps > 1:
             self.last_seq_lens_this_time = paddle.full_like(
-                self.main_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
+                self.target_model_inputs["seq_lens_this_time"], fill_value=-1, dtype="int32"
             )
+        self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
+
+    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
+
+        if "caches" not in self.model_inputs:
+            self.initialize_kv_cache()
+        req_len = len(req_dicts)
+        # has_prefill_task = False
+        # has_decode_task = False
+        for i in range(req_len):
+            request = req_dicts[i]
+            logger.info(f"{i}th request-{request.request_id}: {request}")
+            idx = request.idx
+            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                prefill_start_index = request.prefill_start_index
+                prefill_end_index = request.prefill_end_index
+                length = prefill_end_index - prefill_start_index
+
+                input_ids = request.prompt_token_ids + request.output_token_ids
+
+                self.input_ids_len[idx] = length
+                self.model_inputs["pre_ids"][idx : idx + 1] = -1
+                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
+                    idx : idx + 1, 1:length
+                ]
+                encoder_block_num = len(request.block_tables)
+                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )
+                self.model_inputs["stop_flags"][idx : idx + 1] = False
+                self.model_inputs["batch_drop"][idx : idx + 1] = False
+
+                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = length
+                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
+                self.seq_lens_this_time_buffer[idx : idx + 1] = length
+                self.model_inputs["step_idx"][idx : idx + 1] = (
+                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
+                )
+
+                # has_prefill_task = True
+            elif request.task_type.value == RequestType.DECODE.value:  # decode task
+                encoder_block_num = len(request.block_tables)
+                self.model_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.model_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )
+                # if self.model_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                #     has_decode_task = True
+                #     continue
+            else:
+                self.model_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.model_inputs["stop_flags"][idx : idx + 1] = True
+                self.seq_lens_this_time_buffer[idx : idx + 1] = 0
+                self.model_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+                self.model_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                self.model_inputs["is_block_step"][idx : idx + 1] = False
+                continue
+        # if has_prefill_task or has_decode_task:
+        #     self.model_inputs["not_need_stop"][0] = True
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]

     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
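Editor's note: the prefill branch of insert_tasks_v1 seeds the draft (MTP) model with the target model's tokens shifted left by one position, so the draft head predicts the token following each position the target model has already consumed. A minimal pure-Python sketch of that slice assignment, with made-up token ids:

    # Illustration only: the shift performed above, written out with hypothetical token ids.
    target_input_ids = [101, 7, 42, 99, 5]  # target model's tokens for one request
    length = len(target_input_ids)

    draft_input_ids = [0] * length  # draft (MTP) model's input buffer for that request
    draft_input_ids[: length - 1] = target_input_ids[1:length]

    print(draft_input_ids)  # [7, 42, 99, 5, 0] -- shifted left by one, last slot unused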
@@ -397,9 +464,9 @@ class MTPProposer(Proposer):
             length = len(request.prompt_token_ids)

             if length > 1:
-                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.main_model_inputs["input_ids"][
-                    idx : idx + 1, 1:length
-                ]
+                self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
+                    "input_ids"
+                ][idx : idx + 1, 1:length]
             self.model_inputs["pre_ids"][idx : idx + 1] = -1
             self.model_inputs["step_idx"][idx : idx + 1] = 0
             if self.cache_config.enable_chunked_prefill:
@@ -455,6 +522,7 @@ class MTPProposer(Proposer):
         """
         Prepare MTP inputs
         """
+        use_v1_cache_scheduler = envs.ENABLE_V1_KVCACHE_SCHEDULER
         draft_model_preprocess(
             self.model_inputs["draft_tokens"],
             self.model_inputs["input_ids"],
@@ -465,19 +533,21 @@ class MTPProposer(Proposer):
             self.model_inputs["step_idx"],
             self.model_inputs["not_need_stop"],
             self.model_inputs["batch_drop"],
+            self.model_inputs["is_block_step"],
             self.model_inputs["pre_ids"],
-            self.main_model_inputs["accept_tokens"],
-            self.main_model_inputs["accept_num"],
-            self.main_model_inputs["seq_lens_this_time"],
-            self.main_model_inputs["seq_lens_encoder"],
-            self.main_model_inputs["seq_lens_decoder"],
-            self.main_model_inputs["step_idx"],
-            self.main_model_inputs["stop_flags"],
-            self.main_model_inputs["is_block_step"],
-            self.main_model_inputs["draft_tokens"],
+            self.target_model_inputs["accept_tokens"],
+            self.target_model_inputs["accept_num"],
+            self.target_model_inputs["seq_lens_this_time"],
+            self.target_model_inputs["seq_lens_encoder"],
+            self.target_model_inputs["seq_lens_decoder"],
+            self.target_model_inputs["step_idx"],
+            self.target_model_inputs["stop_flags"],
+            self.target_model_inputs["is_block_step"],
+            self.target_model_inputs["draft_tokens"],
             self.num_model_steps,
             self.speculative_method in ["eagle", "mtp"],
             self.role == "prefill",
+            use_v1_cache_scheduler,
         )

         target_hidden_states = eagle_get_hidden_states(
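Editor's note: the new use_v1_cache_scheduler flag comes from envs.ENABLE_V1_KVCACHE_SCHEDULER, imported above via "from fastdeploy import envs". As a minimal sketch, assuming (not confirmed by this diff) that the flag is ultimately read from the process environment under the same name, v1-scheduler mode would be selected before the engine starts:

    # Sketch only: opting into v1 KV-cache scheduling, assuming ENABLE_V1_KVCACHE_SCHEDULER
    # is sourced from os.environ by fastdeploy.envs (assumption).
    import os

    os.environ["ENABLE_V1_KVCACHE_SCHEDULER"] = "1"

    from fastdeploy import envs

    print(envs.ENABLE_V1_KVCACHE_SCHEDULER)  # expected to be truthy when enabled

When the flag is off, the hunks further below keep the pre-existing mtp_step_paddle block-management path unchanged.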
@@ -486,9 +556,9 @@ class MTPProposer(Proposer):
             self.model_inputs["seq_lens_encoder"],
             self.model_inputs["seq_lens_decoder"],
             self.model_inputs["stop_flags"],
-            self.main_model_inputs["accept_num"],
-            self.main_model_inputs["seq_lens_this_time"],
-            self.main_model_inputs["seq_lens_encoder"],
+            self.target_model_inputs["accept_num"],
+            self.target_model_inputs["seq_lens_this_time"],
+            self.target_model_inputs["seq_lens_encoder"],
             self.num_model_steps,
         )
         if isinstance(target_hidden_states, list):
@@ -658,41 +728,41 @@ class MTPProposer(Proposer):
         Allocate/Free block of MPT.
         """
         draft_model_postprocess(
-            self.main_model_inputs["draft_tokens"],
-            self.main_model_inputs["seq_lens_this_time"],
-            self.main_model_inputs["seq_lens_encoder"],
-            self.main_model_inputs["stop_flags"],
-        )
-
-        mtp_step_paddle(
-            self.main_model_inputs["stop_flags"],
-            self.model_inputs["stop_flags"],
-            self.model_inputs["batch_drop"],
-            self.model_inputs["seq_lens_this_time"],
-            self.model_inputs["seq_lens_encoder"],
-            self.model_inputs["seq_lens_decoder"],
-            self.model_inputs["block_tables"],
-            self.model_inputs["encoder_block_lens"],
-            self.model_inputs["used_list_len"],
-            self.model_inputs["free_list"],
-            self.model_inputs["free_list_len"],
-            self.cache_config.block_size,
-            self.max_draft_token_num,
-        )
+            self.target_model_inputs["draft_tokens"],
+            self.target_model_inputs["seq_lens_this_time"],
+            self.target_model_inputs["seq_lens_encoder"],
+            self.target_model_inputs["stop_flags"],
+        )
+        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            mtp_step_paddle(
+                self.target_model_inputs["stop_flags"],
+                self.model_inputs["stop_flags"],
+                self.model_inputs["batch_drop"],
+                self.model_inputs["seq_lens_this_time"],
+                self.model_inputs["seq_lens_encoder"],
+                self.model_inputs["seq_lens_decoder"],
+                self.model_inputs["block_tables"],
+                self.model_inputs["encoder_block_lens"],
+                self.model_inputs["used_list_len"],
+                self.model_inputs["free_list"],
+                self.model_inputs["free_list_len"],
+                self.cache_config.block_size,
+                self.max_draft_token_num,
+            )

     def _extend_draft_token_with_ngram_match(self):
         # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency
         device = paddle.CUDAPinnedPlace()

-        draft_tokens = self.main_model_inputs["draft_tokens"].cpu()
-        seq_lens_this_time = self.main_model_inputs["seq_lens_this_time"].cpu()
+        draft_tokens = self.target_model_inputs["draft_tokens"].cpu()
+        seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu()
         seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu()
         hybrid_mtp_ngram(
             self.model_inputs["input_ids"]._copy_to(device, True),
             self.input_ids_len,
             self.model_inputs["pre_ids"]._copy_to(device, True),
             self.model_inputs["step_idx"].cpu(),
-            self.main_model_inputs["actual_draft_token_num"].cpu(),
+            self.target_model_inputs["actual_draft_token_num"].cpu(),
             draft_tokens,
             seq_lens_this_time,
             seq_lens_decoder,
@@ -701,8 +771,8 @@ class MTPProposer(Proposer):
             self.min_ngram_size,
             self.max_draft_token_num,
         )
-        self.main_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
-        self.main_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
+        self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda()
+        self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()

     def _run_impl(self, full_hidden_states):
         """"""