[Executor]CUDAGraph support Speculate Decode (#3769)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* success run ngram

* Revert "[Code Simplification] remove cum_offsets (#3410)"

This reverts commit 32b39620bc.

* success run ngram5 tp4 42bs

* success run ngram5 tp4 42bs

* mtp draft commit

* add decorator for target model

* enable draft model in cudagraph v0.5

* revert revrt cum_offset

* enable target model in cudagraph v0.9 And clean debug code

* Revert "success run ngram"

This reverts commit 8351e83993.

* add reverted code

* enable target model in cudagraph v0.9

* solve comment

* fix bid < 0

* Enable Target Model Padding And Draft Model in cudagraph

* solve problem

* delete rebuild padding debug note

* fast compile

* Add capture list for mtp

* success run 256 tp1 mtp

* Enable Lite TP2 Bsz256

* realy enable tp2 bsz 256

* fix problem

* Solve problem for Draft model in cudagraph

* Solve comment

* replace emptytensor as zeros

* Solve comments

* Revert "fast compile"

This reverts commit 834639a7ff.

* fix bug

* fix merge bug

* fix typo

* fix bug

---------

Co-authored-by: lizexu <2694294196@qq.com>
Co-authored-by: littledgg <1658565283@qq.com>
Co-authored-by: zeroRains <linjunlu@zerorains.top>
Co-authored-by: gongshaotian <gstain5555@outlook.com>
This commit is contained in:
RAM
2025-10-09 21:18:29 +08:00
committed by GitHub
parent 7b1689f437
commit aa27b03bc0
19 changed files with 250 additions and 139 deletions

View File

@@ -494,12 +494,12 @@ std::vector<paddle::Tensor> AppendAttention(
paddle::Tensor fmha_out;
if (out_linear_in_scale > 0.0) {
if (fabs(quant_max_bound - 127.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::INT8,
qkv.place());
} else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
fmha_out = GetEmptyTensor(
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
paddle::DataType::FLOAT8_E4M3FN,
qkv.place());
@@ -507,7 +507,7 @@ std::vector<paddle::Tensor> AppendAttention(
PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
}
} else {
fmha_out = GetEmptyTensor(
fmha_out = paddle::zeros(
{meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
dtype_id,
qkv.place());

View File

@@ -2418,6 +2418,9 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
@@ -2437,6 +2440,8 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
continue;
}
using LoadT = AlignedVector<T, vec_size>;

View File

@@ -84,15 +84,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;
@@ -390,15 +382,7 @@ __global__ void append_speculate_cache_rope_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;
@@ -525,15 +509,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
if (block_idx < 0) {
printf(
"Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
"%d %d %d %d\n",
block_idx,
write_seq_id,
ori_bi,
seq_lens_decoder[ori_bi],
token_id,
cu_seqlens_q[ori_bi]);
return ; // NOTE(gongshaotian): For CUDAGraph padding
}
const int block_offset = write_seq_id % block_size;

View File

@@ -684,7 +684,7 @@ void SpeculateVerify(
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts);
void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,

View File

@@ -130,7 +130,6 @@ std::vector<paddle::Tensor> rebuild_padding(
int pack_num = elem_nums / PackSize;
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;
if (output_padding_offset) {
RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(

View File

@@ -139,7 +139,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
PD_BUILD_STATIC_OP(speculate_get_padding_offset)
.Inputs({"input_ids",
"draft_tokens",
"cum_offsets"
"cum_offsets",
"token_num",
"seq_len",
"seq_lens_encoder"})

View File

@@ -73,7 +73,7 @@ __global__ void speculate_verify(
const int *output_cum_offsets, const int *actual_candidate_len,
const int real_bsz, const int max_draft_tokens, const int end_length,
const int max_seq_len, const int max_candidate_len, const int verify_window,
const bool prefill_one_step_stop, const bool benchmark_mode) {
const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) {
const int bid = threadIdx.x;
// verify and set stop flags
int accept_num_now = 1;
@@ -101,6 +101,24 @@ __global__ void speculate_verify(
if (seq_lens_encoder[bid] != 0) {
break;
}
if (accept_all_drafts) {
// accept all draft tokens
step_idx[bid]++;
auto accept_token = draft_tokens_now[i + 1];
accept_tokens[bid * max_draft_tokens + i] = accept_token;
if (is_in_end(accept_token, end_tokens, end_length) ||
step_idx[bid] >= max_dec_len[bid]) {
stop_flags[bid] = true;
stop_flag_now_int = 1;
if (step_idx[bid] >= max_dec_len[bid])
accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
break;
} else {
accept_num_now++;
}
continue;
}
if (USE_TOPK) {
if (verify_tokens_now[i * max_candidate_len] ==
draft_tokens_now[i + 1]) {
@@ -249,7 +267,7 @@ void SpeculateVerify(
const paddle::Tensor &output_cum_offsets,
const paddle::Tensor &actual_candidate_len,
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) {
// printf("Enter speculate update\n");
auto bsz = accept_tokens.shape()[0];
int real_bsz = seq_lens_this_time.shape()[0];
@@ -292,7 +310,7 @@ void SpeculateVerify(
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
end_length, max_seq_len, max_candidate_len, verify_window,
prefill_one_step_stop, benchmark_mode);
prefill_one_step_stop, benchmark_mode, accept_all_drafts);
} else {
speculate_verify<false, true>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -308,7 +326,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
}
} else {
if (enable_topp) {
@@ -326,7 +344,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
} else {
speculate_verify<false, false>
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -342,7 +360,7 @@ void SpeculateVerify(
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
real_bsz, max_draft_tokens, end_length, max_seq_len,
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
}
}
@@ -357,7 +375,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
"actual_candidate_len", "actual_draft_token_nums", "topp"})
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
"stop_flags_out"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"accept_num", "accept_num_out"},
{"step_idx", "step_idx_out"},

View File

@@ -1437,6 +1437,11 @@ class FDConfig:
if self.graph_opt_config.cudagraph_only_prefill:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=512)
elif self.speculative_config is not None and self.speculative_config.method == "mtp":
max_shape = self.scheduler_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
if max_shape % 2 == 1:
max_shape = max_shape + 1
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=min(512, max_shape))
else:
self.graph_opt_config.init_with_cudagrpah_size(max_capture_size=self.scheduler_config.max_num_seqs)

View File

@@ -167,7 +167,7 @@ class ForwardMeta:
"shape": obj.shape,
"dtype": str(obj.dtype),
"place": str(obj.place),
# "content": obj if obj.numel()<10 else "Too big to show"
"content": obj if obj.numel() < 70 else "Too big to show",
}
return tensor_info
elif isinstance(obj, (list, tuple)):

View File

@@ -121,7 +121,7 @@ class CudaGraphPiecewiseBackend:
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug(
f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
f"[CUDA GRAPH][ID:{id(self)}] Warm up for batch size {entry.real_shape}, "
f"finished ({n + 1}/{entry.num_finished_warmup}) times"
)
@@ -148,15 +148,15 @@ class CudaGraphPiecewiseBackend:
real_shape = ids_remove_padding.shape[0]
padding_real_shape = self.real_shape_to_captured_size[real_shape]
logger.debug(
f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
f"The padded shape is :{padding_real_shape}"
f"[CUDA GRAPH][ID:{id(self)}] The actual real shape obtained by CUDAGraph is :{real_shape}, "
f"The padded shape is :{padding_real_shape}, If Padding :{real_shape != padding_real_shape}"
)
entry = self.concrete_size_entries.get(padding_real_shape)
assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
if entry.runnable is None:
entry.runnable = self.runnable
logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] New entry lazy initialize with real shape {padding_real_shape}")
if not entry.use_cudagraph:
return entry.runnable(**kwargs)
@@ -171,7 +171,7 @@ class CudaGraphPiecewiseBackend:
entry.num_finished_warmup += 1
entry.runnable(**kwargs)
logger.debug(
f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
f"[CUDA GRAPH][ID:{id(self)}] Warm up for real shape {padding_real_shape}, "
f"finished ({n + 1}/{entry.num_finished_warmup}) times"
)
@@ -206,11 +206,11 @@ class CudaGraphPiecewiseBackend:
# For CUDAGraph debug
# self._save_cudagrpah_dot_files(entry)
logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}")
# Replay
entry.cuda_graph.replay()
logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph replayed for real shape {padding_real_shape}")
if len(entry.output_buffers) == 1:
return entry.output_buffers[0]
return entry.output_buffers
@@ -223,18 +223,19 @@ class CudaGraphPiecewiseBackend:
for shape in self.cudagraph_capture_sizes:
self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape)
logger.info(
f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
logger.debug(
f"[CUDA GRAPH][ID:{id(self)}] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
"Created all real shape entry."
)
def clear_graph(self):
""" """
# Clear graphs
custom_ar_clear_ipc_handles()
for id, entry in self.concrete_size_entries.items():
for _id, entry in self.concrete_size_entries.items():
if entry.cuda_graph:
del entry.cuda_graph
logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.")
logger.debug(f"[CUDA GRAPH][ID:{id(self)}] The CUDAGraph with shape {_id} has been cleared.")
del self.concrete_size_entries
paddle.device.cuda.empty_cache()

View File

@@ -115,7 +115,7 @@ class GraphOptBackend:
self.runnable = runnable
self.fd_config = fd_config
self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
if self.fd_config.graph_opt_config.graph_opt_level > 0:
# 1. Prepare cuda graph input buffers (contain output of subgraphs)
@@ -138,9 +138,9 @@ class GraphOptBackend:
)
assert kwargs["forward_meta"].ids_remove_padding is not None
batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0]
if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch):
if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
return self.runnable(**kwargs)
else:
return self.cudagraph_piecewise_backend.__call__(**kwargs)

View File

@@ -461,6 +461,7 @@ class SpeculativeSampler(nn.Layer):
sampling_metadata: SamplingMetadata,
max_model_len: int,
share_inputs: List[paddle.Tensor],
accept_all_drafts: bool = False,
) -> paddle.Tensor:
""" """
@@ -517,6 +518,7 @@ class SpeculativeSampler(nn.Layer):
self.speculative_verify_window,
True, # enable_topp
self.speculative_benchmark_mode,
accept_all_drafts,
)
return None

View File

@@ -28,6 +28,9 @@ from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.decorator import (
support_graph_optimization,
)
from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
from fastdeploy.model_executor.layers.normalization import RMSNorm
from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
@@ -234,6 +237,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
return mappings
@support_graph_optimization
class Ernie4_5_MTPModel(nn.Layer):
"""
Ernie4_5_MTPModel
@@ -457,6 +461,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
"""
forward
"""
hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta)
hidden_states = self.ernie(
ids_remove_padding=ids_remove_padding,
previous_hidden_states=previous_hidden_states,
forward_meta=forward_meta,
)
return hidden_states

View File

@@ -33,24 +33,25 @@ class Proposer(ABC):
the speculative decoding framework
"""
def __init__(self, cfg: FDConfig):
def __init__(self, fd_config: FDConfig):
"""
Init Speculative proposer
"""
cfg.parallel_config.tp_group = None
self.cfg = deepcopy(cfg)
cfg.parallel_config.tp_group = dist.get_group(
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
fd_config.parallel_config.tp_group = None
self.fd_config = deepcopy(fd_config)
fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.cfg.parallel_config.tp_group = dist.get_group(
cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
self.fd_config.parallel_config.tp_group = dist.get_group(
fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
)
self.parallel_config = self.cfg.parallel_config
self.model_config = self.cfg.model_config
self.speculative_config = self.cfg.speculative_config
self.cache_config = self.cfg.cache_config
self.quant_config = self.cfg.quant_config
self.scheduler_config = self.cfg.scheduler_config
self.parallel_config = self.fd_config.parallel_config
self.model_config = self.fd_config.model_config
self.speculative_config = self.fd_config.speculative_config
self.cache_config = self.fd_config.cache_config
self.quant_config = self.fd_config.quant_config
self.graph_opt_config = self.fd_config.graph_opt_config
self.scheduler_config = self.fd_config.scheduler_config
self.max_num_seqs = self.scheduler_config.max_num_seqs
self.max_model_len = self.parallel_config.max_model_len

View File

@@ -19,7 +19,6 @@ from typing import List
import numpy as np
import paddle
from paddle import nn
from paddleformers.utils.log import logger
from fastdeploy import envs
@@ -33,6 +32,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
from fastdeploy.model_executor.layers.rotary_embedding import get_rope
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
from fastdeploy.model_executor.model_loader import get_model_loader
from fastdeploy.model_executor.models import ModelForCasualLM
from fastdeploy.model_executor.ops.gpu import (
draft_model_postprocess,
draft_model_preprocess,
@@ -54,12 +55,19 @@ class MTPProposer(Proposer):
Proposer for Multi-Token-Prediction(MTP)
"""
def __init__(self, cfg: FDConfig, main_model: nn.Layer, local_rank: int, device_id: int, target_model_inputs):
super().__init__(cfg)
def __init__(
self,
fd_config: FDConfig,
main_model: ModelForCasualLM,
local_rank: int,
device_id: int, # physical device id
target_model_inputs, # main model share inputs
):
super().__init__(fd_config)
self.num_main_model_layers = self.model_config.num_hidden_layers
self.local_rank = local_rank
self.device_id = device_id
self._update_cfg(main_model)
self._update_mtp_config(main_model)
self._load_model()
self.target_model_inputs = target_model_inputs
self.mtp_strategy = self.speculative_config.mtp_strategy
@@ -67,13 +75,22 @@ class MTPProposer(Proposer):
# [mixed, prefill, decoder]
self.role = "mixed"
self.sampler = MTPSampler(cfg)
self.sampler = MTPSampler(fd_config)
self._init_model_inputs()
# CUDA Graph
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
self.attn_backends: list[AttentionBackend] = []
self._initialize_attn_backend()
def _update_cfg(self, main_model):
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None
def _update_mtp_config(self, main_model):
"""
Update config for MTP from global config
"""
@@ -91,21 +108,17 @@ class MTPProposer(Proposer):
"""
Load MTP Layer
"""
from fastdeploy.model_executor.model_loader import get_model_loader
model_loader = get_model_loader(load_config=self.cfg.load_config)
self.model = model_loader.load_model(fd_config=self.cfg)
model_loader = get_model_loader(load_config=self.fd_config.load_config)
self.model = model_loader.load_model(fd_config=self.fd_config)
def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
"""Set dummy prefill inputs to model_inputs"""
max_dec_len = expected_decode_len + 1
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache()
full_length = min(
input_length = min(
num_tokens // batch_size,
self.parallel_config.max_model_len - max_dec_len,
)
input_length = int(full_length * self.cache_config.kv_cache_ratio)
block_num = (
input_length + self.cache_config.block_size - 1
) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -127,15 +140,15 @@ class MTPProposer(Proposer):
)
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
def initialize_kv_cache(self):
def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
"""
Initialize kv cache
"""
# prompt cache
self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
self.cache_kvs = {}
# Get kv cache dtype
cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if (
self.quant_config
@@ -149,7 +162,7 @@ class MTPProposer(Proposer):
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
)
if not self.parallel_config.do_profile and (
if not profile and (
self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"
):
cache_kvs_list = []
@@ -239,7 +252,7 @@ class MTPProposer(Proposer):
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
self.cfg,
self.fd_config,
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim,
@@ -252,7 +265,7 @@ class MTPProposer(Proposer):
)
self.attn_backends.append(attn_backend)
def clear_dummy_input(self):
def clear_mtp_cache(self):
"""
Clear allocated cacheKV
"""
@@ -260,15 +273,13 @@ class MTPProposer(Proposer):
if self.forward_meta is not None:
del self.forward_meta.caches
def update_block_num(self, num_gpu_blocks) -> None:
def update_mtp_block_num(self, num_gpu_blocks) -> None:
"""
Update block num by theoretical calculation
Update MTP block num by theoretical calculation
"""
# Reset block table and kv cache with global block num
self.main_model_num_gpu_blocks = num_gpu_blocks
self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
if not (self.cache_config.enable_prefix_caching or self.scheduler_config.splitwise_role != "mixed"):
self.initialize_kv_cache()
self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)
# Reset free list
free_list = list(
@@ -285,7 +296,6 @@ class MTPProposer(Proposer):
"free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
}
)
self.parallel_config.do_profile = False
def _init_model_inputs(self):
"""
@@ -309,14 +319,20 @@ class MTPProposer(Proposer):
self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"])
self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"])
self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
self.model_inputs["cu_seqlens_k"] = paddle.clone(self.target_model_inputs["cu_seqlens_k"])
self.model_inputs["decoder_batch_ids"] = paddle.clone(self.target_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["target_hidden_states"] = paddle.full(
[self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
)
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
self.model_inputs["rope_emb"] = get_rope(
@@ -457,10 +473,6 @@ class MTPProposer(Proposer):
"""
Process inputs for prefill tasks and insert it to model_inputs buffer
"""
# NOTE: Lazy initialize kv cache
if "caches" not in self.model_inputs:
self.initialize_kv_cache()
# TODO:Init role in initialize process
if req_dicts[-1].disaggregate_info is not None:
if req_dicts[-1].disaggregate_info["role"] == "prefill":
@@ -539,7 +551,7 @@ class MTPProposer(Proposer):
request.get("block_tables"), dtype="int32"
)
self.model_inputs["not_need_stop"][0] = True
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer
def _initialize_forward_meta(self):
"""
@@ -578,6 +590,33 @@ class MTPProposer(Proposer):
for attn_backend in self.attn_backends:
attn_backend.init_attention_metadata(self.forward_meta)
# Update Batch type for cuda graph
only_decode_batch = True
prefill_exists = None
# Mix ep in single node
if self.fd_config.parallel_config.use_ep and self.fd_config.scheduler_config.splitwise_role == "mixed":
only_decode_batch_list = []
prefill_exists = self.exist_prefill()
paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
only_decode_batch = all(only_decode_batch_list)
self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
self.forward_meta.step_use_cudagraph = (
self.use_cudagraph
and only_decode_batch
and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
)
def exist_prefill(self):
"""
check whether prefill stage exist
"""
if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0:
return 1
else:
return 0
def _prepare_inputs(self, full_hidden_states):
"""
Prepare MTP inputs
@@ -621,10 +660,8 @@ class MTPProposer(Proposer):
self.target_model_inputs["seq_lens_encoder"],
self.num_model_steps,
)
if isinstance(target_hidden_states, list):
target_hidden_states = target_hidden_states[0]
return target_hidden_states
self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
def _post_process(self, sampled_token_ids):
"""
@@ -655,7 +692,7 @@ class MTPProposer(Proposer):
self.parallel_config.use_ep,
)
def _propose(self, target_hidden_states):
def _propose(self):
"""
Main process for MTP inference
"""
@@ -684,11 +721,17 @@ class MTPProposer(Proposer):
self.model_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
# for speculative decoding
self.model_inputs["output_cum_offsets"] = output_cum_offsets
self.model_inputs["output_padding_offset"] = output_padding_offset
# For speculative decoding
self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)
# Initialize forward meta data
self._initialize_forward_meta()
# Padding inputs for cuda graph
self.padding_cudagraph_inputs()
# Get sampling metadata
self.sampling_metadata = SamplingMetadata(
temperature=self.model_inputs["temperature"],
@@ -709,10 +752,11 @@ class MTPProposer(Proposer):
model_output = self.model(
ids_remove_padding=self.model_inputs["ids_remove_padding"],
previous_hidden_states=target_hidden_states,
previous_hidden_states=self.model_inputs["target_hidden_states"],
forward_meta=self.forward_meta,
)
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]
hidden_states = rebuild_padding(
model_output,
self.model_inputs["cu_seqlens_q"],
@@ -737,9 +781,8 @@ class MTPProposer(Proposer):
paddle.distributed.broadcast(sampled_token_ids, 0)
self._post_process(sampled_token_ids)
if substep != self.num_model_steps - 1:
target_hidden_states = self._get_self_hidden_states(hidden_states)
self._get_self_hidden_states(hidden_states)
def _get_self_hidden_states(self, hidden_states):
target_hidden_states = eagle_get_self_hidden_states(
@@ -748,10 +791,7 @@ class MTPProposer(Proposer):
self.model_inputs["seq_lens_this_time"],
self.model_inputs["step_idx"],
)
if isinstance(target_hidden_states, list):
target_hidden_states = target_hidden_states[0]
return target_hidden_states
self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)
def update_task_chunk_prefill(self, task):
"""
@@ -836,8 +876,8 @@ class MTPProposer(Proposer):
def _run_impl(self, full_hidden_states):
""""""
target_hidden_states = self._prepare_inputs(full_hidden_states)
self._propose(target_hidden_states=target_hidden_states)
self._prepare_inputs(full_hidden_states)
self._propose()
self._update_status()
if self.hybrid_mode:
self._extend_draft_token_with_ngram_match()
@@ -845,3 +885,16 @@ class MTPProposer(Proposer):
def is_chunk_prefill_enabled(self):
""""""
return True
def padding_cudagraph_inputs(self) -> None:
"""
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
"""
# In init_attention_metadata, the decode buffer has already been cleared
# To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
if self.use_cudagraph:
self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
return

View File

@@ -29,8 +29,8 @@ class NgramProposer(Proposer):
Matching corresponding tokens in input and output as draft tokens.
"""
def __init__(self, cfg: FDConfig):
super().__init__(cfg)
def __init__(self, fd_config: FDConfig):
super().__init__(fd_config)
self.max_ngram_size = self.speculative_config.max_ngram_size
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()

View File

@@ -124,6 +124,7 @@ class GPUModelRunner(ModelRunnerBase):
"matmul_v2",
"fused_gemm_epilogue",
]
# Sampler
if not self.speculative_decoding:
self.sampler = Sampler(fd_config)
@@ -138,8 +139,7 @@ class GPUModelRunner(ModelRunnerBase):
# Lazy initialize kv cache after model loading
# self.kv_caches: list[paddle.Tensor] = []
# Cuda Graph
self.graph_opt_level = self.graph_opt_config.graph_opt_level
# CUDA Graph
self.use_cudagraph = self.graph_opt_config.use_cudagraph
self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
@@ -160,7 +160,7 @@ class GPUModelRunner(ModelRunnerBase):
# In the future, we will expand it as a list.
self.attn_backends: list[AttentionBackend] = []
# self.attn_metadatas: list[AttentionMetadata] = []
self.initialize_attn_backend()
self._initialize_attn_backend()
# Forward meta store the global meta information of the forward
self.forward_meta: ForwardMeta = None
@@ -1021,7 +1021,6 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
# NOTE: (changwenbin) Initialized to max_num_seq '-1' before copying, marking illegal positions
self.share_inputs["batch_id_per_token"][:] = -1
self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -1035,6 +1034,7 @@ class GPUModelRunner(ModelRunnerBase):
# Initialize forward meta data
self.initialize_forward_meta()
self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)
# Get sampling metadata
self.sampling_metadata = SamplingMetadata(
@@ -1152,7 +1152,6 @@ class GPUModelRunner(ModelRunnerBase):
# Get kv cache dtype
cache_type = self.parallel_config.dtype
kv_cache_quant_type = None
if (
self.quant_config
@@ -1242,7 +1241,7 @@ class GPUModelRunner(ModelRunnerBase):
paddle.device.cuda.empty_cache()
def initialize_attn_backend(self) -> None:
def _initialize_attn_backend(self) -> None:
"""
Initialize attention backends
"""
@@ -1312,6 +1311,7 @@ class GPUModelRunner(ModelRunnerBase):
expected_decode_len: int = 1,
in_capturing: bool = False,
capture_prefill: bool = False,
accept_all_drafts: bool = False,
) -> paddle.Tensor:
"""
Use dummy inputs to run before formal execution.
@@ -1320,6 +1320,7 @@ class GPUModelRunner(ModelRunnerBase):
expected_decode_len: Expected number of tokens generated
in_capturing: Is cuda graph in capturing state
capture_prefill: Capture pure prefill for cuda graph
accept_all_drafts: Target model will accept all draft tokens
"""
input_length_list, max_dec_len_list, block_num = self.get_input_length_list(
@@ -1339,8 +1340,8 @@ class GPUModelRunner(ModelRunnerBase):
batch_size=batch_size,
expected_decode_len=expected_decode_len,
)
while True:
while True:
# 1. Initialize forward meta and attention meta data
self._prepare_inputs()
@@ -1360,6 +1361,8 @@ class GPUModelRunner(ModelRunnerBase):
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]
hidden_states = rebuild_padding(
model_output,
@@ -1404,6 +1407,7 @@ class GPUModelRunner(ModelRunnerBase):
self.sampling_metadata,
self.parallel_config.max_model_len,
self.share_inputs,
accept_all_drafts,
)
sampler_output = None
if self.parallel_config.tensor_parallel_size > 1:
@@ -1470,7 +1474,6 @@ class GPUModelRunner(ModelRunnerBase):
skip_save_output=True,
zmq_client=self.zmq_client,
)
if self.speculative_decoding:
if self.speculative_method == "mtp":
self.proposer.run(full_hidden_states=model_output)
@@ -1565,7 +1568,6 @@ class GPUModelRunner(ModelRunnerBase):
time_before_capture = time.perf_counter()
expected_decode_len = 1
capture_sizes = self.cudagraph_capture_sizes.copy()
if self.fd_config.graph_opt_config.cudagraph_only_prefill:
for num_tokens in sorted(capture_sizes, reverse=True):
self._dummy_run(
@@ -1578,6 +1580,46 @@ class GPUModelRunner(ModelRunnerBase):
logger.info(
f"Warm up the model with the num_tokens:{num_tokens}, expected_decode_len:{expected_decode_len}"
)
elif self.speculative_decoding and self.speculative_method == "mtp":
# Capture Target Model without bsz 1
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture target model for mtp")
else:
assert batch_size % 2 == 0
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=1,
)
logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
# Capture Draft Model without bsz 1
# NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
for batch_size in sorted(capture_sizes, reverse=True):
if batch_size == 1:
logger.info("Skip token_num = 1, when capture Draft model for mtp")
else:
assert batch_size % 2 == 0
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(batch_size / 2),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=True,
)
logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
# Capture Draft Model with bsz 1
if 1 in capture_sizes:
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=int(1),
in_capturing=True,
expected_decode_len=3,
accept_all_drafts=False,
)
logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
else:
for batch_size in sorted(capture_sizes, reverse=True):
self._dummy_run(
@@ -1586,9 +1628,7 @@ class GPUModelRunner(ModelRunnerBase):
in_capturing=True,
expected_decode_len=expected_decode_len,
)
logger.info(
f"Warm up the model with the num_tokens:{batch_size}, expected_decode_len:{expected_decode_len}"
)
logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
time_after_capture = time.perf_counter()
logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
@@ -1674,6 +1714,8 @@ class GPUModelRunner(ModelRunnerBase):
ids_remove_padding=self.share_inputs["ids_remove_padding"],
forward_meta=self.forward_meta,
)
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]
hidden_states = rebuild_padding(
model_output,
self.share_inputs["cu_seqlens_q"],
@@ -1872,25 +1914,25 @@ class GPUModelRunner(ModelRunnerBase):
@profile_run_guard(True)
def profile_run(self) -> None:
"""Execute a forward pass with dummy inputs to profile the memory usage of the model"""
# Initialize kv cache for profile run. After profile run kv cache will be reset.
# TODO(gongshaotian): Optimize the management logic of kvcache
self.num_gpu_blocks = self.parallel_config.total_block_num
self.initialize_kv_cache(profile=True)
if self.speculative_method in ["mtp"]:
self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)
# 1. Profile with multimodal encoder & encoder cache
# 2. Dummy run
self._dummy_run(
num_tokens=self.scheduler_config.max_num_batched_tokens,
batch_size=min(self.scheduler_config.max_num_seqs, 3),
batch_size=self.scheduler_config.max_num_seqs,
)
# 3. gc
self.clear_cache()
if self.speculative_method in ["mtp"]:
self.proposer.clear_dummy_input()
self.proposer.clear_mtp_cache()
def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
"""
@@ -1920,7 +1962,7 @@ class GPUModelRunner(ModelRunnerBase):
)
if self.speculative_method in ["mtp"]:
self.proposer.update_block_num(num_gpu_blocks)
self.proposer.update_mtp_block_num(num_gpu_blocks)
def cal_theortical_kvcache(self):
"""
@@ -2017,6 +2059,7 @@ class GPUModelRunner(ModelRunnerBase):
# To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
if self.use_cudagraph:
self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
return
def _init_image_preprocess(self) -> None:

View File

@@ -207,7 +207,7 @@ class GpuWorker(WorkerBase):
"""
Perform the warm-up and the graph optimization
"""
if self.model_runner.graph_opt_level >= 1:
if self.fd_config.graph_opt_config.graph_opt_level >= 1:
self.model_runner.sot_warmup()
# Trigger cuda graph capture
self.model_runner.capture_model()

View File

@@ -1309,7 +1309,7 @@ class HPUModelRunner(ModelRunnerBase):
accept_num=self.share_inputs["accept_num"] if self.speculative_decoding else None,
)
# if self.speculative_config.method in ["mtp"] and self.parallel_config.splitwise_role == "prefill":
# if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
# skip_save_output = True
# else:
# skip_save_output = False