[Cherry-Pick][BugFix] Add prefill restrictions for chunked_prefill+VL (#2984)

Author: Zero Rains
Date: 2025-07-23 16:53:26 +08:00
Committed by: GitHub
Parent: e5804b1d98
Commit: abd238fc12
2 changed files with 21 additions and 16 deletions

View File

@@ -140,7 +140,14 @@ class GPUModelRunner(ModelRunnerBase):
         """
         Check whether prefill stage finished
         """
-        if int(paddle.max(self.share_inputs['seq_lens_encoder'])) != 0:
-            return 1
-        else:
-            return 0
+        if self.enable_mm:
+            # VL only support 1 batch to prefill
+            prefill_statue = (self.share_inputs["seq_lens_this_time"] != 0) & (
+                self.share_inputs["seq_lens_this_time"] != 1
+            )
+            return not paddle.any(prefill_statue).numpy()
+        else:
+            if int(paddle.max(self.share_inputs["seq_lens_encoder"])) != 0:
+                return 1
+            else:
+                return 0
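For reference, the VL branch above treats a batch slot as still prefilling when its seq_lens_this_time entry is neither 0 (an empty slot) nor 1 (one token per decode step). A minimal standalone sketch of that predicate; the example tensors and slot semantics are illustrative assumptions, not values taken from FastDeploy:

import paddle

def vl_prefill_finished(seq_lens_this_time: paddle.Tensor) -> bool:
    # A slot is mid-prefill when it is active (!= 0) but not yet
    # emitting exactly one token per step (!= 1).
    still_prefilling = (seq_lens_this_time != 0) & (seq_lens_this_time != 1)
    return not bool(paddle.any(still_prefilling))

# One slot still consuming a 128-token prefill chunk -> not finished.
print(vl_prefill_finished(paddle.to_tensor([128, 0, 0])))  # False
# Every active slot is decoding one token per step -> finished.
print(vl_prefill_finished(paddle.to_tensor([1, 1, 0])))    # True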

View File

@@ -23,10 +23,10 @@
 import paddle
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
-from fastdeploy.config import (DecodingConfig, DeviceConfig, FDConfig,
-                               GraphOptimizationConfig, LoadConfig,
-                               ModelConfig, ParallelConfig, SpeculativeConfig,
-                               ErnieArchitectures)
+from fastdeploy.config import (DecodingConfig, DeviceConfig,
+                               ErnieArchitectures, FDConfig,
+                               GraphOptimizationConfig, LoadConfig,
+                               ModelConfig, ParallelConfig, SpeculativeConfig)
 from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
 from fastdeploy.inter_communicator import EngineWorkerQueue as TaskQueue
 from fastdeploy.inter_communicator import IPCSignal
@@ -277,12 +277,12 @@ class PaddleDisWorkerProc():
             # The first worker detects whether there are tasks in the task queue
             if self.local_rank % mp_num_per_node == 0:
                 if self.task_queue.num_tasks() > 0:
-                    if self.nnode > 1:
-                        self.task_queue.read_finish_flag.set(1)
-                    else:
-                        self.exist_task_signal.value[
-                            self.fd_config.parallel_config.
-                            expert_parallel_rank] = 1
+                    # VL only support 1 batch to prefill
+                    if not self.fd_config.model_config.enable_mm or self.worker.prefill_finished():
+                        if self.nnode > 1:
+                            self.task_queue.read_finish_flag.set(1)
+                        else:
+                            self.exist_task_signal.value[self.fd_config.parallel_config.expert_parallel_rank] = 1

             if self.parallel_config.tensor_parallel_size > 1:
                 # Synchronize the signal for other workers
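The effect of the new guard is that the rank-0 worker only signals new tasks for a multimodal model once the in-flight prefill has drained; text-only models are unaffected. A plain-Python sketch of that admission rule, where the function name and scalar arguments are hypothetical stand-ins for the worker state used above:

def should_signal_new_tasks(num_tasks: int, enable_mm: bool,
                            prefill_finished: bool) -> bool:
    # Hypothetical distillation of the guard above: VL only supports
    # one batch in prefill, so new work is blocked until it drains.
    if num_tasks == 0:
        return False
    return (not enable_mm) or prefill_finished

assert should_signal_new_tasks(3, enable_mm=False, prefill_finished=False)
assert not should_signal_new_tasks(3, enable_mm=True, prefill_finished=False)
assert should_signal_new_tasks(3, enable_mm=True, prefill_finished=True)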
@@ -332,10 +332,8 @@ class PaddleDisWorkerProc():
                 # Execute model to generate token. The generated token will be written to the buffer.
                 # These generated tokens can be obtained through get_output op.
                 self.worker.execute_model(req_dicts)
-                self.exist_prefill_task_signal.value[
-                    0] = self.worker.prefill_finished()
+                if not self.fd_config.model_config.enable_mm:
+                    self.exist_prefill_task_signal.value[0] = self.worker.prefill_finished()

     def determine_num_available_blocks(self) -> None:
         """Profiles the peak memory usage of the model to determine how many