support mtp in v1_scheduler mode (#3695)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
freeliuzc
2025-09-04 17:39:59 +08:00
committed by GitHub
parent f265a26f8b
commit 88d44a2c93
11 changed files with 909 additions and 316 deletions

View File

@@ -58,6 +58,7 @@ else:
recover_decode_task,
set_value_by_flags_and_idx,
share_external_data,
speculate_schedule_cache,
)
from fastdeploy.model_executor.pre_and_post_process import (
@@ -394,6 +395,8 @@ class GPUModelRunner(ModelRunnerBase):
if has_prefill_task or has_decode_task:
self.share_inputs["not_need_stop"][0] = True
self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
if self.speculative_method in ["mtp"]:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int = None):
"""
@@ -815,6 +818,13 @@ class GPUModelRunner(ModelRunnerBase):
fill_value=0,
dtype="int32",
)
# For V1_KVCACHE_SCHEDULER
self.share_inputs["step_draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1],
fill_value=0,
dtype="int64",
)
self.share_inputs["step_seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
if self.enable_mm:
head_dim = self.model_config.head_dim
@@ -853,7 +863,11 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["block_tables"],
self.share_inputs["is_block_step"],
self.share_inputs["draft_tokens"] if self.speculative_decoding else None,
self.share_inputs["step_draft_tokens"] if self.speculative_decoding else None,
self.share_inputs["step_seq_lens_this_time"] if self.speculative_decoding else None,
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens if self.speculative_decoding else 0,
)
# Remove padding
@@ -1556,6 +1570,24 @@ class GPUModelRunner(ModelRunnerBase):
self._update_chunked_prefill(model_forward_batch)
self._add_cache(model_forward_batch)
elif self.speculative_decoding:
speculate_schedule_cache(
self.share_inputs["draft_tokens"],
self.share_inputs["block_tables"],
self.share_inputs["stop_flags"],
self.share_inputs["seq_lens_this_time"],
self.share_inputs["seq_lens_decoder"],
self.share_inputs["step_seq_lens_decoder"],
self.share_inputs["step_draft_tokens"],
self.share_inputs["step_seq_lens_this_time"],
self.share_inputs["accept_num"],
self.share_inputs["accept_tokens"],
self.share_inputs["is_block_step"],
self.share_inputs["not_need_stop"],
self.share_inputs["stop_nums"],
self.cache_config.block_size,
self.speculative_config.num_speculative_tokens,
)
self.seq_lens_this_time_buffer[:num_running_requests].copy_(
self.share_inputs["seq_lens_this_time"][:num_running_requests], False