Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-01 12:22:53 +08:00
[Executor]CUDAGraph support Speculate Decode (#4258)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Executor]CUDAGraph support Speculate Decode
* fix problem
* solve problem
* fix
* fast compile
* CUDAGraph + mtp support eb5(only target model)
* Revert "fast compile"
This reverts commit 3cfe8373ed.
* fix precommit
* solve comment
* fix comment about #pragma unroll
---------
Co-authored-by: gongshaotian <gstain5555@outlook.com>
Co-authored-by: gongshaotian <gstian5555@outlook.com>
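
The changes below all serve one constraint: a CUDA graph replays a fixed sequence of kernels against tensors whose shapes and device addresses were frozen at capture time, so per-step data must be copied into persistent buffers instead of being bound to freshly allocated tensors. A minimal sketch of that contract, assuming Paddle's CUDA graph wrapper (paddle.device.cuda.graphs.CUDAGraph) and a CUDA device; this is illustrative only, not FastDeploy's actual capture path:

import paddle
from paddle.device.cuda.graphs import CUDAGraph

# Persistent buffers: the graph is captured against these exact addresses.
inp = paddle.zeros([8, 16], dtype="float32")
weight = paddle.randn([16, 16], dtype="float32")

graph = CUDAGraph()
graph.capture_begin()
out = paddle.matmul(inp, weight)  # kernels are recorded here, not re-dispatched later
graph.capture_end()

# Refresh the input in place (same address, same shape), then replay.
inp.copy_(paddle.randn([8, 16], dtype="float32"), False)
graph.replay()
print(out.numpy().shape)  # the captured output buffer now holds the new result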
@@ -33,31 +33,33 @@ class Proposer(ABC):
     the speculative decoding framework
     """
 
-    def __init__(self, cfg: FDConfig):
+    def __init__(self, fd_config: FDConfig):
         """
         Init Speculative proposer
         """
-        cfg.parallel_config.tp_group = None
-        cfg.parallel_config.ep_group = None
-        self.cfg = deepcopy(cfg)
-        cfg.parallel_config.tp_group = dist.get_group(
-            cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+        fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
+        self.fd_config = deepcopy(fd_config)
+        fd_config.parallel_config.tp_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
-        cfg.parallel_config.ep_group = dist.get_group(
-            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.cfg.parallel_config.tp_group = dist.get_group(
-            cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+        self.fd_config.parallel_config.tp_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.cfg.parallel_config.ep_group = dist.get_group(
-            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
         )
 
-        self.parallel_config = self.cfg.parallel_config
-        self.model_config = self.cfg.model_config
-        self.speculative_config = self.cfg.speculative_config
-        self.cache_config = self.cfg.cache_config
-        self.quant_config = self.cfg.quant_config
+        self.parallel_config = self.fd_config.parallel_config
+        self.model_config = self.fd_config.model_config
+        self.speculative_config = self.fd_config.speculative_config
+        self.cache_config = self.fd_config.cache_config
+        self.quant_config = self.fd_config.quant_config
+        self.graph_opt_config = self.fd_config.graph_opt_config
+        self.scheduler_config = self.fd_config.scheduler_config
 
         self.max_num_seqs = self.parallel_config.max_num_seqs
         self.max_model_len = self.parallel_config.max_model_len
@@ -22,6 +22,7 @@ import paddle
 from paddleformers.utils.log import logger
 
 from fastdeploy import envs
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend
@@ -31,6 +32,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
+from fastdeploy.model_executor.model_loader import get_model_loader
+from fastdeploy.model_executor.models import ModelForCasualLM
 from fastdeploy.model_executor.ops.gpu import (
     draft_model_postprocess,
     draft_model_preprocess,
@@ -52,12 +55,19 @@ class MTPProposer(Proposer):
     Proposer for Multi-Token-Prediction(MTP)
     """
 
-    def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
-        super().__init__(cfg)
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        main_model: ModelForCasualLM,
+        local_rank: int,
+        device_id: int,  # physical device id
+        target_model_inputs,  # main model share inputs
+    ):
+        super().__init__(fd_config)
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
         self.device_id = device_id
-        self._update_cfg(main_model)
+        self._update_mtp_config(main_model)
         self._load_model()
         self.target_model_inputs = target_model_inputs
         self.mtp_strategy = self.speculative_config.mtp_strategy
@@ -65,16 +75,22 @@ class MTPProposer(Proposer):
 
         # [mixed, prefill, decoder]
         self.role = "mixed"
-        self.sampler = MTPSampler(cfg)
+        self.sampler = MTPSampler(fd_config)
         self._init_model_inputs()
 
+        # CUDA Graph
+        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
+        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
+
         self.attn_backends: list[AttentionBackend] = []
         self._initialize_attn_backend()
 
-    def _update_cfg(self, main_model):
+    def _update_mtp_config(self, main_model):
         """
         Update config for MTP from global config
         """
+        self.forward_meta: ForwardMeta = None
         self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
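
The constructor above only wires up the knobs for now (use_cudagraph is still forced to False, with the real graph_opt_config flag left in the trailing comment) and keeps cudagraph_capture_sizes in descending order. A hedged sketch of how such a descending list is typically consumed at replay time: pick the smallest captured batch that still fits the live batch, then pad up to it. The helper name and exact policy are illustrative, not FastDeploy's API:

def pick_padded_batch(live_batch_size: int, capture_sizes_desc: list[int]) -> int:
    """Return the smallest captured batch size >= live_batch_size (sizes given in descending order)."""
    chosen = None
    for size in capture_sizes_desc:  # e.g. [256, 128, 64, 32, 16, 8, 4, 2, 1]
        if size >= live_batch_size:
            chosen = size            # keep shrinking while it still fits
        else:
            break
    if chosen is None:
        raise ValueError("batch larger than every captured size; fall back to eager execution")
    return chosen

assert pick_padded_batch(3, [8, 4, 2, 1]) == 4
assert pick_padded_batch(8, [8, 4, 2, 1]) == 8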
@@ -89,21 +105,18 @@ class MTPProposer(Proposer):
         """
         Load MTP Layer
         """
-        from fastdeploy.model_executor.model_loader import get_model_loader
 
-        model_loader = get_model_loader(load_config=self.cfg.load_config)
-        self.model = model_loader.load_model(fd_config=self.cfg)
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        self.model = model_loader.load_model(fd_config=self.fd_config)
 
     def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to model_inputs"""
         max_dec_len = expected_decode_len + 1
-        self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
-        full_length = min(
+        input_length = min(
             num_tokens // batch_size,
             self.parallel_config.max_model_len - max_dec_len,
         )
-        input_length = int(full_length * self.cache_config.kv_cache_ratio)
+
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
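
For reference, the block_num expression kept above is a ceiling division over the cache block size plus the reserved encoder/decoder blocks. A worked example with made-up values (block_size=64, enc_dec_block_num=2 are assumptions, not FastDeploy defaults):

input_length = 300
block_size = 64
enc_dec_block_num = 2
block_num = (input_length + block_size - 1) // block_size + enc_dec_block_num
# ceil(300 / 64) = 5 cache blocks, plus 2 reserved blocks -> 7
assert block_num == 7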
@@ -125,13 +138,15 @@ class MTPProposer(Proposer):
         )
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
-    def initialize_kv_cache(self):
+    def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
         """
         Initialize kv cache
         """
         # prompt cache
 
+        self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
         self.cache_kvs = {}
 
         # Get kv cache dtype
         cache_type = self.parallel_config.dtype
 
         kv_cache_quant_type = None
@@ -151,9 +166,7 @@ class MTPProposer(Proposer):
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
 
-        if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,
@@ -230,7 +243,7 @@ class MTPProposer(Proposer):
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
-            self.cfg,
+            self.fd_config,
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim,
@@ -243,7 +256,7 @@ class MTPProposer(Proposer):
         )
         self.attn_backends.append(attn_backend)
 
-    def clear_dummy_input(self):
+    def clear_mtp_cache(self):
         """
         Clear allocated cacheKV
         """
@@ -251,15 +264,14 @@ class MTPProposer(Proposer):
         if self.forward_meta is not None:
             del self.forward_meta.caches
 
-    def update_block_num(self, num_gpu_blocks) -> None:
+    def update_mtp_block_num(self, num_gpu_blocks) -> None:
         """
-        Update block num by theoretical calculation
+        Update MTP block num by theoretical calculation
         """
 
         # Reset block table and kv cache with global block num
         self.main_model_num_gpu_blocks = num_gpu_blocks
         self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
         if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+            self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)
 
         # Reset free list
         free_list = list(
@@ -276,7 +288,6 @@ class MTPProposer(Proposer):
                 "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
             }
         )
-        self.parallel_config.do_profile = False
 
     def _init_model_inputs(self):
         """
@@ -300,6 +311,8 @@ class MTPProposer(Proposer):
         self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
         self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
         self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
+        self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"])
+        self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"])
         self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
         self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
         self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
@@ -308,6 +321,9 @@ class MTPProposer(Proposer):
         self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
+        self.model_inputs["target_hidden_states"] = paddle.full(
+            [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
+        )
 
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
         self.model_inputs["rope_emb"] = get_rope(
@@ -443,9 +459,6 @@ class MTPProposer(Proposer):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
         """
-        # NOTE: Lazy initialize kv cache
-        if "caches" not in self.model_inputs:
-            self.initialize_kv_cache()
 
         # TODO:Init role in initialize process
         if req_dicts[-1].disaggregate_info is not None:
@@ -526,7 +539,7 @@ class MTPProposer(Proposer):
                 request.get("block_tables"), dtype="int32"
             )
         self.model_inputs["not_need_stop"][0] = True
-        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
 
     def _initialize_forward_meta(self):
         """
@@ -556,6 +569,33 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)
 
+        # Update Batch type for cuda graph
+        only_decode_batch = True
+        prefill_exists = None
+
+        # Mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            only_decode_batch_list = []
+            prefill_exists = self.exist_prefill()
+            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+            only_decode_batch = all(only_decode_batch_list)
+            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+
+        self.forward_meta.step_use_cudagraph = (
+            self.use_cudagraph
+            and only_decode_batch
+            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
+        )
+
+    def exist_prefill(self):
+        """
+        check whether prefill stage exist
+        """
+        if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0:
+            return 1
+        else:
+            return 0
+
     def _prepare_inputs(self, full_hidden_states):
         """
         Prepare MTP inputs
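
The logic added above is the heart of the feature: a step may replay a captured graph only when CUDA graphs are enabled, every rank in a mixed-EP deployment reports a decode-only batch, and no prefill exists locally. The same gate, reduced to a standalone sketch with illustrative names (not the FastDeploy API):

def step_use_cudagraph(use_cudagraph: bool,
                       local_has_prefill: bool,
                       peers_decode_only: list[bool]) -> bool:
    """Replay only when graphs are on, all peers are decode-only, and no local prefill exists."""
    only_decode_batch = all(peers_decode_only)  # mirrors all_gather_object(not prefill_exists)
    return use_cudagraph and only_decode_batch and not local_has_prefill

assert step_use_cudagraph(True, False, [True, True]) is True
assert step_use_cudagraph(True, True, [False, True]) is False
assert step_use_cudagraph(False, False, [True]) is False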
@@ -599,10 +639,8 @@ class MTPProposer(Proposer):
             self.target_model_inputs["seq_lens_encoder"],
             self.num_model_steps,
         )
-        if isinstance(target_hidden_states, list):
-            target_hidden_states = target_hidden_states[0]
 
-        return target_hidden_states
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
 
     def _post_process(self, sampled_token_ids):
         """
@@ -633,7 +671,7 @@ class MTPProposer(Proposer):
             self.parallel_config.use_ep,
         )
 
-    def _propose(self, target_hidden_states):
+    def _propose(self):
         """
         Main process for MTP inference
         """
@@ -663,10 +701,15 @@ class MTPProposer(Proposer):
                 self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
                 self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
                 # for speculative decoding
-                self.model_inputs["output_cum_offsets"] = output_cum_offsets
-                self.model_inputs["output_padding_offset"] = output_padding_offset
+                self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
+                self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
+
                 # Initialize forward meta data
                 self._initialize_forward_meta()
+
+                # Padding inputs for cuda graph
+                self.padding_cudagraph_inputs()
+
                 # Get sampling metadata
                 self.sampling_metadata = SamplingMetadata(
                     temperature=self.model_inputs["temperature"],
@@ -687,9 +730,11 @@ class MTPProposer(Proposer):
 
                 model_output = self.model(
                     ids_remove_padding=self.model_inputs["ids_remove_padding"],
-                    previous_hidden_states=target_hidden_states,
+                    previous_hidden_states=self.model_inputs["target_hidden_states"],
                     forward_meta=self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]
 
                 hidden_states = rebuild_padding(
                     model_output,
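
Passing self.model_inputs["target_hidden_states"] instead of a loose tensor, and trimming the padded output back to real_token_num, both follow from the capture-time contract: the graphed forward always runs at the captured size and always reads the same buffer. A small sketch of that discipline with made-up shapes (the matmul stands in for the graphed forward; this is not the real model call):

import paddle

max_tokens, hidden = 32, 8
target_buf = paddle.full([max_tokens, hidden], 0, dtype="float32")  # persistent, fixed-address buffer

# Refresh the whole buffer in place; the address the graph was captured with never changes.
fresh = paddle.randn([max_tokens, hidden], dtype="float32")
target_buf.copy_(fresh, False)

real_token_num = 5                                                   # only the first 5 rows are live tokens
padded_out = paddle.matmul(target_buf, paddle.randn([hidden, hidden]))
model_output = padded_out[:real_token_num]                           # trim back to the real tokens, as above
print(model_output.shape)                                            # [5, 8]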
@@ -721,7 +766,7 @@ class MTPProposer(Proposer):
                 self._post_process(sampled_token_ids)
 
                 if substep != self.num_model_steps - 1:
-                    target_hidden_states = self._get_self_hidden_states(hidden_states)
+                    self._get_self_hidden_states(hidden_states)
                 else:
                     if hasattr(self.model, "empty_input_forward"):
                         self.model.empty_input_forward()
@@ -733,10 +778,7 @@ class MTPProposer(Proposer):
             self.model_inputs["seq_lens_this_time"],
             self.model_inputs["step_idx"],
         )
-        if isinstance(target_hidden_states, list):
-            target_hidden_states = target_hidden_states[0]
-
-        return target_hidden_states
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
 
     def update_task_chunk_prefill(self, task):
         """
@@ -821,8 +863,8 @@ class MTPProposer(Proposer):
 
     def _run_impl(self, full_hidden_states):
         """"""
-        target_hidden_states = self._prepare_inputs(full_hidden_states)
-        self._propose(target_hidden_states=target_hidden_states)
+        self._prepare_inputs(full_hidden_states)
+        self._propose()
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
@@ -830,3 +872,16 @@ class MTPProposer(Proposer):
     def is_chunk_prefill_enabled(self):
         """"""
         return True
+
+    def padding_cudagraph_inputs(self) -> None:
+        """
+        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
+        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
+        """
+        # In init_attention_metadata, the decode buffer has already been cleared
+
+        # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
+        if self.use_cudagraph:
+            self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
+            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
+        return
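
padding_cudagraph_inputs keeps the forward pass at the maximum batch size and records how many tokens are real; padded slots stay zeroed so they contribute nothing when the graph replays. A hedged illustration with made-up numbers (not FastDeploy's buffer layout):

import paddle

max_num_seqs = 8
seq_lens_this_time_buffer = paddle.zeros([max_num_seqs], dtype="int32")  # persistent, max-batch sized

live = [1, 1, 1]                                                         # three decode requests, one token each
seq_lens_this_time_buffer[: len(live)] = paddle.to_tensor(live, dtype="int32")
seq_lens_this_time_buffer[len(live):] = 0                                # keep the padded tail clean before replay

real_token_num = int(seq_lens_this_time_buffer.sum())
print(real_token_num)                                                    # 3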
@@ -29,8 +29,8 @@ class NgramProposer(Proposer):
     Matching corresponding tokens in input and output as draft tokens.
     """
 
-    def __init__(self, cfg: FDConfig):
-        super().__init__(cfg)
+    def __init__(self, fd_config: FDConfig):
+        super().__init__(fd_config)
         self.max_ngram_size = self.speculative_config.max_ngram_size
         self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()