[Feature] Support block scheduler v1 for FD (#2928)
* Support FD block scheduler v1
* Fix according to copilot review
* Fix according to review
* Remove is_dummy
* Fix bug when real_bsz=1
* Fix infer first token cost time

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
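The whole feature is gated on `envs.ENABLE_V1_KVCACHE_SCHEDULER` (see the `GpuWorker` hunk at the end of this diff). A minimal sketch of opting in from the environment, assuming the flag is read from a variable of the same name as the `insert_tasks_v1` docstring below suggests; the launch command itself is a placeholder:

```python
# Hedged sketch: enable the v1 KV-cache block scheduler for a FastDeploy worker process.
# Assumption: the flag is read from the ENABLE_V1_KVCACHE_SCHEDULER environment variable
# (per the insert_tasks_v1 docstring); "serve_cmd" is a placeholder, not a real CLI.
import os
import subprocess

env = dict(os.environ, ENABLE_V1_KVCACHE_SCHEDULER="1")
subprocess.run(["serve_cmd"], env=env, check=True)
```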
@@ -24,7 +24,7 @@ from paddle import nn
 from paddleformers.utils.log import logger
 
 from fastdeploy.config import FDConfig
-from fastdeploy.engine.request import Request
+from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.graph_optimization.utils import (
     profile_run_guard,
     sot_warmup_guard,
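`RequestType` is what lets the runner tell prefill, decode, and preempted tasks apart in `insert_tasks_v1` further down. The real enum lives in `fastdeploy.engine.request` and is not shown in this diff, so the sketch below is a stand-in with assumed members and values; only the branching pattern mirrors the code that follows.

```python
# Stand-in only: the real RequestType is defined in fastdeploy.engine.request.
from enum import Enum


class RequestType(Enum):  # assumed members/values; PREFILL and DECODE are referenced below
    PREFILL = 0
    DECODE = 1
    PREEMPTED = 2


def describe(task_type: RequestType) -> str:
    # Same branch order as insert_tasks_v1: prefill, decode, everything else preempted.
    if task_type.value == RequestType.PREFILL.value:
        return "write a prompt chunk into the slot"
    if task_type.value == RequestType.DECODE.value:
        return "refresh the slot's block table only"
    return "reset and stop the slot"
```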
@@ -42,6 +42,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import Sampler, SpeculativeSampler
 from fastdeploy.model_executor.model_loader import get_model_from_loader
 from fastdeploy.model_executor.ops.gpu import (
+    recover_decode_task,
     set_value_by_flags_and_idx,
     share_external_data,
 )
@@ -56,6 +57,7 @@ from fastdeploy.platforms import current_platform
 if not current_platform.is_dcu():
     from fastdeploy.spec_decode import MTPProposer, NgramProposer
 
+from fastdeploy import envs
 from fastdeploy.input.mm_processor import DataProcessor
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ScatterOp
@@ -189,10 +191,97 @@ class GPUModelRunner(ModelRunnerBase):
         elif request.structural_tag is not None:
             schemata_key = ("structural_tag", request.structural_tag)
 
-        return (
-            self.guided_backend.get_logits_processor(schemata_key=schemata_key),
-            schemata_key,
-        )
+        return self.guided_backend.get_logits_processor(schemata_key=schemata_key), schemata_key
+
+    def insert_tasks_v1(self, req_dicts: List[Request]):
+        """
+        Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1
+        """
+        # NOTE(luotingdan): Lazy initialize kv cache
+        if "caches" not in self.share_inputs:
+            self.initialize_kv_cache()
+
+        req_len = len(req_dicts)
+        has_prefill_task = False
+        for i in range(req_len):
+            request = req_dicts[i]
+            idx = request.idx
+            if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                logger.debug(f"Handle prefill request {request} at idx {idx}")
+                prefill_start_index = request.prefill_start_index
+                prefill_end_index = request.prefill_end_index
+                length = prefill_end_index - prefill_start_index
+                input_ids = request.prompt_token_ids + request.output_token_ids
+                self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
+                    input_ids[prefill_start_index:prefill_end_index]
+                )
+                encoder_block_num = len(request.block_tables)
+                self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )
+                self.share_inputs["stop_flags"][idx : idx + 1] = False
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = prefill_start_index
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length
+                self.share_inputs["step_seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["prompt_lens"][idx : idx + 1] = len(input_ids)
+                self.share_inputs["is_block_step"][idx : idx + 1] = False
+                self.share_inputs["step_idx"][idx : idx + 1] = (
+                    len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
+                )
+                has_prefill_task = True
+            elif request.task_type.value == RequestType.DECODE.value:  # decode task
+                logger.debug(f"Handle decode request {request} at idx {idx}")
+                encoder_block_num = len(request.block_tables)
+                self.share_inputs["encoder_block_lens"][idx : idx + 1] = encoder_block_num
+                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
+                    request.block_tables, dtype="int32"
+                )
+                continue
+            else:  # preempted task
+                logger.debug(f"Handle preempted request {request} at idx {idx}")
+                self.share_inputs["block_tables"][idx : idx + 1, :] = -1
+                self.share_inputs["stop_flags"][idx : idx + 1] = True
+                self.share_inputs["seq_lens_this_time"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
+                self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
+                self.share_inputs["is_block_step"][idx : idx + 1] = False
+                continue
+
+            if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens:
+                request.eos_token_ids.append(request.eos_token_ids[0])
+            self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
+
+            self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
+            self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
+            self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
+            self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
+            self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0)
+
+            self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1)
+            self.share_inputs["max_dec_len"][idx : idx + 1] = request.get(
+                "max_tokens", self.model_config.max_model_len
+            )
+
+            self.share_inputs["first_token_ids"][idx : idx + 1] = self.share_inputs["input_ids"][idx : idx + 1, :1]
+            self.share_inputs["ori_seq_lens_encoder"][idx : idx + 1] = length
+
+            if request.get("seed") is not None:
+                self.share_inputs["infer_seed"][idx : idx + 1] = request.get("seed")
+
+            if request.get("stop_token_ids") is not None and request.get("stop_seqs_len") is not None:
+                stop_seqs_num = len(request.get("stop_seqs_len"))
+                for i in range(stop_seqs_num, self.model_config.max_stop_seqs_num):
+                    request.stop_seqs_len.append(0)
+                self.share_inputs["stop_seqs_len"][:] = np.array(request.stop_seqs_len, dtype="int32")
+                self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
+                    request.get("stop_token_ids"), dtype="int64"
+                )
+        if has_prefill_task:
+            self.share_inputs["not_need_stop"][0] = True
 
     def insert_prefill_inputs(self, req_dicts: List[Request]):
         """
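The prefill branch above supports chunked prefill: `prefill_start_index` / `prefill_end_index` select the slice of the prompt handled in this scheduling step, `seq_lens_decoder` records how many tokens already sit in the KV cache, and `step_idx` only becomes non-zero once the last chunk has been written. A stripped-down NumPy sketch of that bookkeeping, using plain arrays in place of the `share_inputs` paddle tensors (field names are copied from the diff, everything else is simplified for illustration):

```python
import numpy as np

# Simplified per-slot buffers standing in for the paddle tensors in share_inputs.
MAX_LEN, MAX_BLOCKS, MAX_SLOTS = 32, 8, 4
share = {
    "input_ids": np.zeros((MAX_SLOTS, MAX_LEN), dtype="int64"),
    "block_tables": np.full((MAX_SLOTS, MAX_BLOCKS), -1, dtype="int32"),
    "seq_lens_decoder": np.zeros(MAX_SLOTS, dtype="int32"),
    "seq_lens_this_time": np.zeros(MAX_SLOTS, dtype="int32"),
    "seq_lens_encoder": np.zeros(MAX_SLOTS, dtype="int32"),
    "step_idx": np.zeros(MAX_SLOTS, dtype="int64"),
}


def fill_prefill_chunk(idx, prompt, output, start, end, block_tables):
    """Mimic the prefill branch of insert_tasks_v1 for one slot (simplified)."""
    input_ids = prompt + output
    length = end - start
    share["input_ids"][idx, :length] = input_ids[start:end]
    share["block_tables"][idx, :] = -1
    share["block_tables"][idx, : len(block_tables)] = block_tables
    share["seq_lens_decoder"][idx] = start        # tokens already in the KV cache
    share["seq_lens_this_time"][idx] = length     # tokens processed in this step
    share["seq_lens_encoder"][idx] = length
    # step_idx stays 0 until the final chunk of the prompt has been consumed
    share["step_idx"][idx] = len(output) if end >= len(input_ids) else 0


# A 6-token prompt split into two chunks for slot 0:
fill_prefill_chunk(0, [11, 12, 13, 14, 15, 16], [], start=0, end=4, block_tables=[3, 7])
fill_prefill_chunk(0, [11, 12, 13, 14, 15, 16], [], start=4, end=6, block_tables=[3, 7])
print(int(share["step_idx"][0]))  # 0: prompt fully consumed, no output tokens generated yet
```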
@@ -591,6 +680,18 @@ class GPUModelRunner(ModelRunnerBase):
 
     def _prepare_inputs(self) -> None:
         """Prepare the model inputs"""
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            recover_decode_task(
+                self.share_inputs["stop_flags"],
+                self.share_inputs["seq_lens_this_time"],
+                self.share_inputs["seq_lens_encoder"],
+                self.share_inputs["seq_lens_decoder"],
+                self.share_inputs["step_seq_lens_decoder"],
+                self.share_inputs["block_tables"],
+                self.share_inputs["is_block_step"],
+                self.parallel_config.block_size,
+            )
+
         # Remove padding
         (
             ids_remove_padding,
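`recover_decode_task` is a GPU op (newly imported above) that re-activates decode slots previously parked via `is_block_step` once the v1 scheduler has given them KV blocks again. The kernel itself is not part of this diff, so the following is only a rough CPU analogue of the idea under that assumption, not the real op semantics:

```python
import numpy as np


def recover_decode_task_cpu(stop_flags, seq_lens_this_time, seq_lens_encoder,
                            seq_lens_decoder, step_seq_lens_decoder,
                            block_tables, is_block_step, block_size):
    """Rough CPU analogue (assumption): resume parked slots whose block table
    again covers the next token's KV block."""
    for i in range(stop_flags.shape[0]):
        if not is_block_step[i]:
            continue
        needed = (int(step_seq_lens_decoder[i]) + 1 + block_size - 1) // block_size
        if (block_tables[i, :needed] >= 0).all():
            is_block_step[i] = False
            stop_flags[i] = False
            seq_lens_decoder[i] = step_seq_lens_decoder[i]
            seq_lens_this_time[i] = 1
            seq_lens_encoder[i] = 0
```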
@@ -901,6 +1002,8 @@ class GPUModelRunner(ModelRunnerBase):
             post_process(
                 sampler_output=sampler_output,
                 model_output=model_output_data,
+                share_inputs=self.share_inputs,
+                block_size=self.parallel_config.block_size,
                 speculative_decoding=self.speculative_decoding,
                 skip_save_output=True,
             )
@@ -1165,6 +1268,8 @@ class GPUModelRunner(ModelRunnerBase):
         post_process(
             sampler_output=sampler_output,
             model_output=model_output_data,
+            share_inputs=self.share_inputs,
+            block_size=self.parallel_config.block_size,
             save_each_rank=self.parallel_config.use_ep,
             speculative_decoding=self.speculative_decoding,
             skip_save_output=skip_save_output,
@@ -1180,16 +1285,17 @@ class GPUModelRunner(ModelRunnerBase):
         # 7. Updata 'infer_seed' and step_cuda()
         self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
         self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
-        step_cuda(
-            self.share_inputs,
-            self.parallel_config.block_size,
-            self.parallel_config.enc_dec_block_num,
-            self.speculative_config,
-            self.parallel_config.enable_prefix_caching,
-        )
+        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            step_cuda(
+                self.share_inputs,
+                self.parallel_config.block_size,
+                self.parallel_config.enc_dec_block_num,
+                self.speculative_config,
+                self.parallel_config.enable_prefix_caching,
+            )
 
-        self._update_chunked_prefill(model_forward_batch)
-        self._add_cache(model_forward_batch)
+            self._update_chunked_prefill(model_forward_batch)
+            self._add_cache(model_forward_batch)
         return None
 
     def _add_cache(self, model_forward_batch) -> None:
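With the v1 scheduler enabled, block stepping and chunked-prefill/cache bookkeeping move out of the runner (the scheduler hands the runner updated block tables through `insert_tasks_v1` instead), so `step_cuda`, `_update_chunked_prefill`, and `_add_cache` only run on the legacy path. A minimal sketch of the resulting control flow; the helper names below are hypothetical stand-ins, not FastDeploy API:

```python
# Hypothetical helpers used only to illustrate the gating introduced above.
def post_sampling_step(runner, model_forward_batch, v1_scheduler_enabled: bool) -> None:
    runner.advance_infer_seed()  # stand-in for the infer_seed add_/modulo lines kept above
    if not v1_scheduler_enabled:
        # Legacy path: the runner itself steps block tables and maintains the prefix cache.
        runner.step_and_bookkeep(model_forward_batch)  # stand-in for step_cuda + chunked prefill + _add_cache
    # v1 path: nothing to do here; the next insert_tasks_v1 call carries fresh block tables.
```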
@@ -22,6 +22,7 @@ import paddle
 import pynvml
 from paddle import nn
 
+from fastdeploy import envs
 from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request
 from fastdeploy.platforms import current_platform
@@ -183,7 +184,10 @@ class GpuWorker(WorkerBase):
         TODO(gongshaotian):The scheduler should schedule the handling of prefill,
         and workers and modelrunners should not perceive it.
         """
-        self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            self.model_runner.insert_tasks_v1(req_dicts=req_dicts)
+        else:
+            self.model_runner.insert_prefill_inputs(req_dicts=req_dicts)
 
     def graph_optimize_and_warm_up_model(self) -> None:
         """
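On the worker side the change is a single dispatch: scheduler output goes to `insert_tasks_v1` when the flag is set and to the legacy `insert_prefill_inputs` otherwise. A self-contained sketch of that branching (the enclosing method name is not shown in this hunk, so a standalone helper plus a mock runner is used here):

```python
from unittest.mock import MagicMock


def route_new_tasks(model_runner, req_dicts, v1_enabled: bool) -> None:
    # Same branching as the GpuWorker hunk above, expressed as a standalone helper.
    if v1_enabled:
        model_runner.insert_tasks_v1(req_dicts=req_dicts)
    else:
        model_runner.insert_prefill_inputs(req_dicts=req_dicts)


runner = MagicMock()
route_new_tasks(runner, req_dicts=[], v1_enabled=True)
runner.insert_tasks_v1.assert_called_once_with(req_dicts=[])
```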