Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Feature][MTP]Support MTP for rl-model (#4009)
* qk norm for speculate decode C16
* support mtp in v1_scheduler mode
* support mtp rope_3d
* support mtp features
* add unit test && del some log

Co-authored-by: yuanxiaolan <yuanxiaolan01@baidu.com>
Co-authored-by: xiaoxiaohehe001 <hiteezsf@163.com>
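Background for reviewers: below is a minimal, self-contained sketch of the dispatch pattern this change adds for v1-scheduler mode. Only `speculative_method`, `insert_tasks_v1`, and the request list mirror the real code; the `Request` stand-in, the `MTPProposer` body, and the `ModelRunnerSketch` class are illustrative assumptions, not FastDeploy APIs.

# Illustrative sketch only -- names other than insert_tasks_v1 / speculative_method
# are assumptions, not FastDeploy APIs.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Request:
    """Stand-in for a scheduled request."""
    request_id: str
    prompt_token_ids: List[int] = field(default_factory=list)


class MTPProposer:
    """Hypothetical draft-model proposer used for multi-token prediction (MTP)."""

    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int) -> None:
        # Mirror the target model's scheduling state into the draft model's
        # inputs so draft tokens are proposed for the same requests.
        for req in req_dicts[:num_running_requests]:
            print(f"[mtp] tracking request {req.request_id}")


class ModelRunnerSketch:
    def __init__(self, speculative_method: Optional[str] = None):
        self.speculative_method = speculative_method
        self.proposer = MTPProposer() if speculative_method == "mtp" else None

    def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int) -> None:
        # ... target-model bookkeeping elided ...
        if self.speculative_method in ["mtp"]:
            # Same dispatch the diff below adds to GPUModelRunner.
            self.proposer.insert_tasks_v1(req_dicts, num_running_requests)


if __name__ == "__main__":
    runner = ModelRunnerSketch(speculative_method="mtp")
    runner.insert_tasks_v1([Request("req-0"), Request("req-1")], num_running_requests=2)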
@@ -59,6 +59,7 @@ else:
         recover_decode_task,
         set_value_by_flags_and_idx,
         share_external_data,
+        speculate_schedule_cache,
     )
 
 from fastdeploy.model_executor.pre_and_post_process import (
@@ -383,6 +384,8 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["stop_flags"].sum() == self.parallel_config.max_num_seqs
             )
         self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
+        if self.speculative_method in ["mtp"]:
+            self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
 
     def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
         """
@@ -803,6 +806,13 @@ class GPUModelRunner(ModelRunnerBase):
                 fill_value=0,
                 dtype="int32",
             )
+            # For V1_KVCACHE_SCHEDULER
+            self.share_inputs["step_draft_tokens"] = paddle.full(
+                shape=[max_num_seqs, max_draft_token_num + 1],
+                fill_value=0,
+                dtype="int64",
+            )
+            self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
 
         if self.enable_mm:
             head_dim = self.model_config.head_dim
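The step_draft_tokens / step_seq_lens_this_time buffers added above are plain pre-allocated Paddle tensors. A standalone allocation sketch follows; the concrete sizes are illustrative assumptions, and the meaning of the extra "+ 1" column is not stated in the diff.

# Sketch of the new per-step speculative buffers (sizes are illustrative).
import paddle

max_num_seqs = 4         # assumption: batch capacity
max_draft_token_num = 3  # assumption: number of draft tokens per step

share_inputs = {}

# One row per sequence slot; the extra column mirrors the
# max_draft_token_num + 1 width used in the diff (purpose not stated there).
share_inputs["step_draft_tokens"] = paddle.full(
    shape=[max_num_seqs, max_draft_token_num + 1],
    fill_value=0,
    dtype="int64",
)
share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")

print(share_inputs["step_draft_tokens"].shape)        # [4, 4]
print(share_inputs["step_seq_lens_this_time"].shape)  # [4, 1]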
@@ -841,7 +851,11 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["step_seq_lens_decoder"],
                 self.share_inputs["block_tables"],
                 self.share_inputs["is_block_step"],
+                self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
+                self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None,
+                self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None,
                 self.cache_config.block_size,
+                self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
             )
 
         # Remove padding
@@ -1540,6 +1554,24 @@ class GPUModelRunner(ModelRunnerBase):
 
             self._update_chunked_prefill(model_forward_batch)
             self._add_cache(model_forward_batch)
+        elif self.speculative_decoding:
+            speculate_schedule_cache(
+                self.share_inputs["draft_tokens"],
+                self.share_inputs["block_tables"],
+                self.share_inputs["stop_flags"],
+                self.share_inputs["seq_lens_this_time"],
+                self.share_inputs["seq_lens_decoder"],
+                self.share_inputs["step_seq_lens_decoder"],
+                self.share_inputs["step_draft_tokens"],
+                self.share_inputs["step_seq_lens_this_time"],
+                self.share_inputs["accept_num"],
+                self.share_inputs["accept_tokens"],
+                self.share_inputs["is_block_step"],
+                self.share_inputs["not_need_stop"],
+                self.share_inputs["stop_nums"],
+                self.cache_config.block_size,
+                self.speculative_config.num_speculative_tokens,
+            )
 
         self.seq_lens_this_time_buffer[:num_running_requests].copy_(
             self.share_inputs["seq_lens_this_time"][:num_running_requests], False
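speculate_schedule_cache itself is a fused GPU custom op imported in the first hunk; the diff shows only its call site. Purely as a reading aid, here is a conceptual pure-Python sketch of what a "stash the draft state of block-stepped requests" pass could look like. The semantics are an assumption inferred from the argument names, not the op's actual implementation.

# Conceptual sketch (pure Python/NumPy, not the CUDA op). Semantics are an
# assumption inferred from the argument names of speculate_schedule_cache.
import numpy as np

def stash_block_stepped_drafts(draft_tokens, step_draft_tokens,
                               seq_lens_this_time, step_seq_lens_this_time,
                               is_block_step):
    # For every sequence flagged as block-stepped (paused, e.g. because it ran
    # out of KV-cache blocks), copy its live draft tokens and current step
    # length into the step_* buffers so they can be restored when it resumes.
    for i in range(draft_tokens.shape[0]):
        if is_block_step[i]:
            step_draft_tokens[i, :] = draft_tokens[i, :]
            step_seq_lens_this_time[i] = seq_lens_this_time[i]
            seq_lens_this_time[i] = 0  # the parked request does no work this step

# Tiny usage example: 2 sequences, 3 draft-token slots, sequence 1 is parked.
draft_tokens = np.array([[11, 12, 13], [21, 22, 23]], dtype=np.int64)
step_draft_tokens = np.zeros_like(draft_tokens)
seq_lens_this_time = np.array([3, 3], dtype=np.int32)
step_seq_lens_this_time = np.zeros(2, dtype=np.int32)
is_block_step = np.array([False, True])

stash_block_stepped_drafts(draft_tokens, step_draft_tokens,
                           seq_lens_this_time, step_seq_lens_this_time,
                           is_block_step)
print(step_draft_tokens[1], step_seq_lens_this_time[1])  # [21 22 23] 3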