[Executor]CUDAGraph support Speculate Decode (#4258)

* [Executor]CUDAGraph support Speculate Decode

* fix problem

* solve problem

* fix

* fast compile

* CUDAGraph + MTP support for EB5 (target model only)

* Revert "fast compile"

This reverts commit 3cfe8373ed.

* fix precommit

* address review comments

* fix comment about #pragma unroll

---------

Co-authored-by: gongshaotian <gstain5555@outlook.com>
Co-authored-by: gongshaotian <gstian5555@outlook.com>
Jundong Liu
2025-10-13 15:21:41 +08:00
committed by GitHub
parent 07db281647
commit 0b7a5778ab
16 changed files with 265 additions and 134 deletions

View File

@@ -2410,6 +2410,9 @@ __global__ void merge_multi_chunks_v2_kernel(
   __shared__ float md_smem[bdy * 2];
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
     const uint32_t bid = batch_id_per_token[qid];
+    if(bid == -1){
+      continue;
+    }
     const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     const int seq_len_q = seq_lens_q[bid];
     if (seq_len_q == 0) continue;
@@ -2427,7 +2430,7 @@ __global__ void merge_multi_chunks_v2_kernel(
       seq_len_kv += seq_len_q;
     }
     const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
-    if (num_chunks_this_seq <= 1) {
+    if (num_chunks_this_seq <= 1 || !ENABLE_PREFILL) {
       continue;
     }

View File

@@ -72,6 +72,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
     const int token_id = linear_index / hidden_size;
     const int ori_bi = batch_id_per_token[token_id];
+    if (ori_bi == -1) continue;
     if (seq_lens_decoder[ori_bi] == 0) continue;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -84,15 +85,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      printf(
-          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
-          "%d %d %d %d\n",
-          block_idx,
-          write_seq_id,
-          ori_bi,
-          seq_lens_decoder[ori_bi],
-          token_id,
-          cu_seqlens_q[ori_bi]);
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
@@ -149,13 +142,13 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     if (hi < num_heads) {
       Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
#pragma unroll
       for (int i = 0; i < VecSize; i++) {
         bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {
       Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
#pragma unroll
       for (int i = 0; i < VecSize; i++) {
         bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }
@@ -390,15 +383,7 @@ __global__ void append_speculate_cache_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      printf(
-          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
-          "%d %d %d %d\n",
-          block_idx,
-          write_seq_id,
-          ori_bi,
-          seq_lens_decoder[ori_bi],
-          token_id,
-          cu_seqlens_q[ori_bi]);
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
@@ -525,15 +510,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      printf(
-          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
-          "%d %d %d %d\n",
-          block_idx,
-          write_seq_id,
-          ori_bi,
-          seq_lens_decoder[ori_bi],
-          token_id,
-          cu_seqlens_q[ori_bi]);
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
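
These early-return guards exist because CUDAGraph replay can run the kernels over padded token slots that belong to no real request. A host-side sketch of the sentinel convention the kernels rely on (hypothetical shapes, illustration only, not part of this commit):

import paddle

max_tokens = 8     # token count the CUDA graph was captured with
real_tokens = 5    # tokens actually present at replay time

# batch_id_per_token: -1 marks a padded slot; the kernels above continue/return on it.
batch_id_per_token = paddle.full([max_tokens], -1, dtype="int32")
batch_id_per_token[:real_tokens] = paddle.to_tensor([0, 0, 1, 1, 2], dtype="int32")

# block_tables: unassigned blocks stay -1, so `block_idx < 0` likewise means "padding, write nothing".
block_tables = paddle.full([3, 4], -1, dtype="int32")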

View File

@@ -676,7 +676,7 @@ void SpeculateVerify(
     const paddle::Tensor &output_cum_offsets,
     const paddle::Tensor &actual_candidate_len,
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
-    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts);
 void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                      const paddle::Tensor &seq_lens_decoder,

View File

@@ -139,6 +139,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
 PD_BUILD_STATIC_OP(speculate_get_padding_offset)
     .Inputs({"input_ids",
              "draft_tokens",
+             "cum_offsets",
              "token_num",
              "seq_len",
              "seq_lens_encoder"})

View File

@@ -79,7 +79,7 @@ __global__ void speculate_verify(
     const int *output_cum_offsets, const int *actual_candidate_len,
     const int real_bsz, const int max_draft_tokens, const int end_length,
     const int max_seq_len, const int max_candidate_len, const int verify_window,
-    const bool prefill_one_step_stop, const bool benchmark_mode) {
+    const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) {
   const int bid = threadIdx.x;
   // verify and set stop flags
   int accept_num_now = 1;
@@ -107,6 +107,24 @@ __global__ void speculate_verify(
         if (seq_lens_encoder[bid] != 0) {
           break;
         }
+        if (accept_all_drafts) {
+          // accept all draft tokens
+          step_idx[bid]++;
+          auto accept_token = draft_tokens_now[i + 1];
+          accept_tokens[bid * max_draft_tokens + i] = accept_token;
+          if (is_in_end(accept_token, end_tokens, end_length) ||
+              step_idx[bid] >= max_dec_len[bid]) {
+            stop_flags[bid] = true;
+            stop_flag_now_int = 1;
+            if (step_idx[bid] >= max_dec_len[bid])
+              accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
+            break;
+          } else {
+            accept_num_now++;
+          }
+          continue;
+        }
         if (USE_TOPK) {
           if (verify_tokens_now[i * max_candidate_len] ==
               draft_tokens_now[i + 1]) {
@@ -255,7 +273,7 @@ void SpeculateVerify(
     const paddle::Tensor &output_cum_offsets,
     const paddle::Tensor &actual_candidate_len,
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
-    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) {
  // printf("Enter speculate update\n");
  auto bsz = accept_tokens.shape()[0];
  int real_bsz = seq_lens_this_time.shape()[0];
@@ -298,7 +316,7 @@ void SpeculateVerify(
              is_block_step.data<bool>(), output_cum_offsets.data<int>(),
              actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
              end_length, max_seq_len, max_candidate_len, verify_window,
-              prefill_one_step_stop, benchmark_mode);
+              prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    } else {
      speculate_verify<false, true>
          <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -314,7 +332,7 @@ void SpeculateVerify(
              end_tokens.data<int64_t>(), is_block_step.data<bool>(),
              output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
              real_bsz, max_draft_tokens, end_length, max_seq_len,
-              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
+              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    }
  } else {
    if (enable_topp) {
@@ -332,7 +350,7 @@ void SpeculateVerify(
              end_tokens.data<int64_t>(), is_block_step.data<bool>(),
              output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
              real_bsz, max_draft_tokens, end_length, max_seq_len,
-              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
+              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    } else {
      speculate_verify<false, false>
          <<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -348,7 +366,7 @@ void SpeculateVerify(
              end_tokens.data<int64_t>(), is_block_step.data<bool>(),
              output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
              real_bsz, max_draft_tokens, end_length, max_seq_len,
-              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
+              max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    }
  }
@@ -363,7 +381,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
              "actual_candidate_len", "actual_draft_token_nums", "topp"})
    .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
              "stop_flags_out"})
-    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
+    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"})
    .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                    {"accept_num", "accept_num_out"},
                    {"step_idx", "step_idx_out"},
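
The new accept_all_drafts attribute short-circuits verification so every draft token is taken; the capture path later in this commit passes it during dummy runs so the draft model always sees the maximum number of accepted tokens. A rough Python rendering of that branch (illustration only; the authoritative logic is the CUDA kernel above):

def accept_drafts_unconditionally(draft_tokens, end_tokens, step_idx, max_dec_len):
    """Mirror of the accept_all_drafts branch: take every draft token until EOS or the length limit."""
    accepted = []
    for tok in draft_tokens:
        step_idx += 1
        accepted.append(tok)
        if tok in end_tokens or step_idx >= max_dec_len:
            if step_idx >= max_dec_len:
                accepted[-1] = end_tokens[0]  # force an end token at the length limit
            break
    return accepted, step_idx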

View File

@@ -1150,7 +1150,14 @@ class FDConfig:
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
-        self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
+
+        if self.speculative_config is not None and self.speculative_config.method == "mtp":
+            max_shape = self.parallel_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
+            if max_shape % 2 == 1:
+                max_shape = max_shape + 1
+            self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=min(512, max_shape))
+        else:
+            self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)

         # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
         if self.graph_opt_config.graph_opt_level == 2:
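
For MTP, the capture list is sized by the worst-case number of decode tokens per step, max_num_seqs * (num_speculative_tokens + 1), rounded up to an even value and capped at 512. A small standalone sketch of that arithmetic (not FastDeploy API):

def mtp_max_capture_shape(max_num_seqs: int, num_speculative_tokens: int) -> int:
    # Each request contributes one verified token plus its draft tokens per decode step.
    max_shape = max_num_seqs * (num_speculative_tokens + 1)
    if max_shape % 2 == 1:
        max_shape += 1   # kept even; the capture loop later halves each size
    return min(512, max_shape)

print(mtp_max_capture_shape(128, 1))  # -> 256
print(mtp_max_capture_shape(256, 2))  # -> 512 (capped)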

View File

@@ -133,6 +133,7 @@ class ForwardMeta:
                 "shape": obj.shape,
                 "dtype": str(obj.dtype),
                 "place": str(obj.place),
+                "content": obj if obj.numel() < 70 else "Too big to show",
             }
             return tensor_info
         elif isinstance(obj, (list, tuple)):

View File

@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Optional
@@ -111,7 +112,7 @@ class CudaGraphPiecewiseBackend:
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
                 logger.debug(
-                    f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
+                    f"[CUDA GRAPH] [ID:{id(self)}] Warm up for batch size {entry.real_shape}, "
                     f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )
@@ -138,15 +139,17 @@ class CudaGraphPiecewiseBackend:
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
-            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
-            f"The padded shape is :{padding_real_shape}"
+            f"[CUDA GRAPH] [ID:{id(self)}] The actual real shape obtained by CUDAGraph is :{real_shape}, "
+            f"The padded shape is :{padding_real_shape}, If Padding :{real_shape != padding_real_shape}"
         )
         entry = self.concrete_size_entries.get(padding_real_shape)
         assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
         if entry.runnable is None:
             entry.runnable = self.runnable
-            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")
+            logger.debug(
+                f"[CUDA GRAPH] [ID:{id(self)}] New entry lazy initialize with real shape {padding_real_shape}"
+            )
         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)
@@ -161,7 +164,7 @@ class CudaGraphPiecewiseBackend:
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
                 logger.debug(
-                    f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
+                    f"[CUDA GRAPH] [ID:{id(self)}] Warm up for real shape {padding_real_shape}, "
                    f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )
@@ -196,11 +199,11 @@ class CudaGraphPiecewiseBackend:
             # For CUDAGraph debug
             # self._save_cudagrpah_dot_files(entry)
-            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")
+            logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}")
         # Replay
         entry.cuda_graph.replay()
-        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
+        logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph replayed for real shape {padding_real_shape}")
         if len(entry.output_buffers) == 1:
             return entry.output_buffers[0]
         return entry.output_buffers
@@ -214,16 +217,17 @@ class CudaGraphPiecewiseBackend:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape)
         logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
+            f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
+            "Created all real shape entry."
         )
     def clear_graph(self):
         """ """
         # Clear graphs
-        for id, entry in self.concrete_size_entries.items():
+        for _id, entry in self.concrete_size_entries.items():
             if entry.cuda_graph:
                 del entry.cuda_graph
-                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.")
+                logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] The CUDAGraph with shape {_id} has been cleared.")
         del self.concrete_size_entries
         paddle.device.cuda.empty_cache()
@@ -236,6 +240,6 @@ class CudaGraphPiecewiseBackend:
         log_dir = envs.FD_LOG_DIR
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}_time{time.perf_counter()}",
                 1 << 0,
             )

View File

@@ -115,7 +115,7 @@ class GraphOptBackend:
         self.runnable = runnable
         self.fd_config = fd_config
-        self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
+        self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]

         if self.fd_config.graph_opt_config.graph_opt_level > 0:
             # 1. Prepare cuda grpah input buffers (contain output of subgraphs)
@@ -138,9 +138,9 @@ class GraphOptBackend:
             )

         assert kwargs["forward_meta"].ids_remove_padding is not None
-        batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
+        real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0]

-        if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch):
+        if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
             return self.runnable(**kwargs)
         else:
             return self.cudagraph_piecewise_backend.__call__(**kwargs)
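
The renamed variables make the dispatch rule explicit: run eagerly when the step opted out of CUDAGraph or the real shape exceeds the largest captured size, otherwise hand off to the piecewise CUDAGraph backend. A condensed sketch of that rule (standalone helper, not FastDeploy API):

def choose_backend(step_use_cudagraph: bool, real_shape: int, max_capture_size: int) -> str:
    """Return which path GraphOptBackend would take for this forward pass."""
    if (not step_use_cudagraph) or (real_shape > max_capture_size):
        return "eager"       # run self.runnable directly
    return "cudagraph"       # replay via CudaGraphPiecewiseBackend

assert choose_backend(True, 48, 512) == "cudagraph"
assert choose_backend(True, 513, 512) == "eager"
assert choose_backend(False, 8, 512) == "eager"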

View File

@@ -395,6 +395,7 @@ class SpeculativeSampler(nn.Layer):
         sampling_metadata: SamplingMetadata,
         max_model_len: int,
         share_inputs: List[paddle.Tensor],
+        accept_all_drafts: bool = False,
     ) -> paddle.Tensor:
         """ """
@@ -451,6 +452,7 @@ class SpeculativeSampler(nn.Layer):
             self.speculative_verify_window,
             True,  # enable_topp
             self.speculative_benchmark_mode,
+            accept_all_drafts,
         )

         return None

View File

@@ -27,6 +27,9 @@ from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
 from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
 from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer
@@ -229,6 +232,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
         return mappings

+@support_graph_optimization
 class Ernie4_5_MTPModel(nn.Layer):
     """
     Ernie4_5_MTPModel
@@ -435,6 +439,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         """
         forward
         """
-        hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta)
+        hidden_states = self.ernie(
+            ids_remove_padding=ids_remove_padding,
+            previous_hidden_states=previous_hidden_states,
+            forward_meta=forward_meta,
+        )
         return hidden_states

View File

@@ -33,31 +33,33 @@ class Proposer(ABC):
     the speculative decoding framework
     """

-    def __init__(self, cfg: FDConfig):
+    def __init__(self, fd_config: FDConfig):
         """
         Init Speculative proposer
         """
-        cfg.parallel_config.tp_group = None
-        cfg.parallel_config.ep_group = None
-        self.cfg = deepcopy(cfg)
-        cfg.parallel_config.tp_group = dist.get_group(
-            cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+        fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
+        self.fd_config = deepcopy(fd_config)
+        fd_config.parallel_config.tp_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
-        cfg.parallel_config.ep_group = dist.get_group(
-            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.cfg.parallel_config.tp_group = dist.get_group(
-            cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+        self.fd_config.parallel_config.tp_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.cfg.parallel_config.ep_group = dist.get_group(
-            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.parallel_config = self.cfg.parallel_config
-        self.model_config = self.cfg.model_config
-        self.speculative_config = self.cfg.speculative_config
-        self.cache_config = self.cfg.cache_config
-        self.quant_config = self.cfg.quant_config
+        self.parallel_config = self.fd_config.parallel_config
+        self.model_config = self.fd_config.model_config
+        self.speculative_config = self.fd_config.speculative_config
+        self.cache_config = self.fd_config.cache_config
+        self.quant_config = self.fd_config.quant_config
+        self.graph_opt_config = self.fd_config.graph_opt_config
+        self.scheduler_config = self.fd_config.scheduler_config
         self.max_num_seqs = self.parallel_config.max_num_seqs
         self.max_model_len = self.parallel_config.max_model_len

View File

@@ -22,6 +22,7 @@ import paddle
 from paddleformers.utils.log import logger

 from fastdeploy import envs
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend
@@ -31,6 +32,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
+from fastdeploy.model_executor.model_loader import get_model_loader
+from fastdeploy.model_executor.models import ModelForCasualLM
 from fastdeploy.model_executor.ops.gpu import (
     draft_model_postprocess,
     draft_model_preprocess,
@@ -52,12 +55,19 @@ class MTPProposer(Proposer):
     Proposer for Multi-Token-Prediction(MTP)
     """

-    def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
-        super().__init__(cfg)
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        main_model: ModelForCasualLM,
+        local_rank: int,
+        device_id: int,  # physical device id
+        target_model_inputs,  # main model share inputs
+    ):
+        super().__init__(fd_config)
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
         self.device_id = device_id
-        self._update_cfg(main_model)
+        self._update_mtp_config(main_model)
         self._load_model()
         self.target_model_inputs = target_model_inputs
         self.mtp_strategy = self.speculative_config.mtp_strategy
@@ -65,16 +75,22 @@ class MTPProposer(Proposer):
         # [mixed, prefill, decoder]
         self.role = "mixed"

-        self.sampler = MTPSampler(cfg)
+        self.sampler = MTPSampler(fd_config)
         self._init_model_inputs()

+        # CUDA Graph
+        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
+        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

         self.attn_backends: list[AttentionBackend] = []
         self._initialize_attn_backend()

-    def _update_cfg(self, main_model):
+    def _update_mtp_config(self, main_model):
         """
         Update config for MTP from global config
         """
+        self.forward_meta: ForwardMeta = None
         self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1
@@ -89,21 +105,18 @@ class MTPProposer(Proposer):
         """
         Load MTP Layer
         """
-        from fastdeploy.model_executor.model_loader import get_model_loader
-
-        model_loader = get_model_loader(load_config=self.cfg.load_config)
-        self.model = model_loader.load_model(fd_config=self.cfg)
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        self.model = model_loader.load_model(fd_config=self.fd_config)

     def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to model_inputs"""
         max_dec_len = expected_decode_len + 1
-        self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
-        full_length = min(
+        input_length = min(
             num_tokens // batch_size,
             self.parallel_config.max_model_len - max_dec_len,
         )
-        input_length = int(full_length * self.cache_config.kv_cache_ratio)
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -125,13 +138,15 @@ class MTPProposer(Proposer):
         )
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

-    def initialize_kv_cache(self):
+    def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
         """
         Initialize kv cache
         """
+        # prompt cache
+        self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
         self.cache_kvs = {}
+        # Get kv cache dtype
         cache_type = self.parallel_config.dtype

         kv_cache_quant_type = None
@@ -151,9 +166,7 @@ class MTPProposer(Proposer):
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,
@@ -230,7 +243,7 @@ class MTPProposer(Proposer):
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
-            self.cfg,
+            self.fd_config,
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim,
@@ -243,7 +256,7 @@ class MTPProposer(Proposer):
         )
         self.attn_backends.append(attn_backend)

-    def clear_dummy_input(self):
+    def clear_mtp_cache(self):
         """
         Clear allocated cacheKV
         """
@@ -251,15 +264,14 @@ class MTPProposer(Proposer):
         if self.forward_meta is not None:
             del self.forward_meta.caches

-    def update_block_num(self, num_gpu_blocks) -> None:
+    def update_mtp_block_num(self, num_gpu_blocks) -> None:
         """
-        Update block num by theoretical calculation
+        Update MTP block num by theoretical calculation
         """
-        # Reset block table and kv cache with global block num
         self.main_model_num_gpu_blocks = num_gpu_blocks
-        self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)

         # Reset free list
         free_list = list(
@@ -276,7 +288,6 @@ class MTPProposer(Proposer):
                 "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
             }
         )
-        self.parallel_config.do_profile = False

     def _init_model_inputs(self):
         """
@@ -300,6 +311,8 @@ class MTPProposer(Proposer):
         self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
         self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
         self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
+        self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"])
+        self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"])
         self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
         self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
         self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
@@ -308,6 +321,9 @@ class MTPProposer(Proposer):
         self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
+        self.model_inputs["target_hidden_states"] = paddle.full(
+            [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
+        )
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
         self.model_inputs["rope_emb"] = get_rope(
@@ -443,9 +459,6 @@ class MTPProposer(Proposer):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
         """
-        # NOTE: Lazy initialize kv cache
-        if "caches" not in self.model_inputs:
-            self.initialize_kv_cache()

         # TODO:Init role in initialize process
         if req_dicts[-1].disaggregate_info is not None:
@@ -526,7 +539,7 @@ class MTPProposer(Proposer):
                     request.get("block_tables"), dtype="int32"
                 )
         self.model_inputs["not_need_stop"][0] = True
-        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

     def _initialize_forward_meta(self):
         """
@@ -556,6 +569,33 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

+        # Update Batch type for cuda graph
+        only_decode_batch = True
+        prefill_exists = None
+
+        # Mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            only_decode_batch_list = []
+            prefill_exists = self.exist_prefill()
+            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+            only_decode_batch = all(only_decode_batch_list)
+            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+
+        self.forward_meta.step_use_cudagraph = (
+            self.use_cudagraph
+            and only_decode_batch
+            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
+        )
+
+    def exist_prefill(self):
+        """
+        check whether prefill stage exist
+        """
+        if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0:
+            return 1
+        else:
+            return 0
+
     def _prepare_inputs(self, full_hidden_states):
         """
         Prepare MTP inputs
@@ -599,10 +639,8 @@ class MTPProposer(Proposer):
             self.target_model_inputs["seq_lens_encoder"],
             self.num_model_steps,
         )
-        if isinstance(target_hidden_states, list):
-            target_hidden_states = target_hidden_states[0]
-        return target_hidden_states
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

     def _post_process(self, sampled_token_ids):
         """
@@ -633,7 +671,7 @@ class MTPProposer(Proposer):
             self.parallel_config.use_ep,
         )

-    def _propose(self, target_hidden_states):
+    def _propose(self):
         """
         Main process for MTP inference
         """
@@ -663,10 +701,15 @@ class MTPProposer(Proposer):
             self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
             self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
             # for speculative decoding
-            self.model_inputs["output_cum_offsets"] = output_cum_offsets
-            self.model_inputs["output_padding_offset"] = output_padding_offset
+            self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
+            self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

+            # Initialize forward meta data
             self._initialize_forward_meta()
+            # Padding inputs for cuda graph
+            self.padding_cudagraph_inputs()

             # Get sampling metadata
             self.sampling_metadata = SamplingMetadata(
                 temperature=self.model_inputs["temperature"],
@@ -687,9 +730,11 @@ class MTPProposer(Proposer):
             model_output = self.model(
                 ids_remove_padding=self.model_inputs["ids_remove_padding"],
-                previous_hidden_states=target_hidden_states,
+                previous_hidden_states=self.model_inputs["target_hidden_states"],
                 forward_meta=self.forward_meta,
             )
+            if self.use_cudagraph:
+                model_output = model_output[: self.real_token_num]

             hidden_states = rebuild_padding(
                 model_output,
@@ -721,7 +766,7 @@ class MTPProposer(Proposer):
             self._post_process(sampled_token_ids)

             if substep != self.num_model_steps - 1:
-                target_hidden_states = self._get_self_hidden_states(hidden_states)
+                self._get_self_hidden_states(hidden_states)
             else:
                 if hasattr(self.model, "empty_input_forward"):
                     self.model.empty_input_forward()
@@ -733,10 +778,7 @@ class MTPProposer(Proposer):
             self.model_inputs["seq_lens_this_time"],
             self.model_inputs["step_idx"],
         )
-        if isinstance(target_hidden_states, list):
-            target_hidden_states = target_hidden_states[0]
-        return target_hidden_states
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

     def update_task_chunk_prefill(self, task):
         """
@@ -821,8 +863,8 @@ class MTPProposer(Proposer):
     def _run_impl(self, full_hidden_states):
         """"""
-        target_hidden_states = self._prepare_inputs(full_hidden_states)
-        self._propose(target_hidden_states=target_hidden_states)
+        self._prepare_inputs(full_hidden_states)
+        self._propose()
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()
@@ -830,3 +872,16 @@ class MTPProposer(Proposer):
     def is_chunk_prefill_enabled(self):
         """"""
         return True
+
+    def padding_cudagraph_inputs(self) -> None:
+        """
+        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
+        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
+        """
+        # In init_attention_metadata, the decode buffer has already been cleared
+        # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
+        if self.use_cudagraph:
+            self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
+            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
+        return
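
Taken together, the proposer now replays a fixed-shape graph: seq_lens_this_time is pointed back at the full-size buffer, the real token count is recorded, and the rows produced for padded slots are sliced off after the forward pass. A simplified replay flow (Python sketch with hypothetical names, not the actual class):

def replay_with_padding(model, inputs, captured_shape: int, real_token_num: int):
    """Sketch: run a graph captured at `captured_shape` tokens on a smaller real batch."""
    assert real_token_num <= captured_shape
    # Buffers already hold `captured_shape` rows; padded rows carry -1 sentinels,
    # so the kernels skip them (see the bid == -1 / block_idx < 0 guards above).
    output = model(**inputs)          # replays the captured graph at its fixed shape
    return output[:real_token_num]    # keep only the rows that belong to real tokens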

View File

@@ -29,8 +29,8 @@ class NgramProposer(Proposer):
     Matching corresponding tokens in input and output as draft tokens.
     """

-    def __init__(self, cfg: FDConfig):
-        super().__init__(cfg)
+    def __init__(self, fd_config: FDConfig):
+        super().__init__(fd_config)
         self.max_ngram_size = self.speculative_config.max_ngram_size
         self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()

View File

@@ -135,7 +135,6 @@ class GPUModelRunner(ModelRunnerBase):
         # self.kv_caches: list[paddle.Tensor] = []

         # Cuda Graph
-        self.graph_opt_level = self.graph_opt_config.graph_opt_level
         self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
@@ -155,7 +154,7 @@ class GPUModelRunner(ModelRunnerBase):
         # In the future, we will expand it as a list.
         self.attn_backends: list[AttentionBackend] = []
         # self.attn_metadatas: list[AttentionMetadata] = []
-        self.initialize_attn_backend()
+        self._initialize_attn_backend()

         # Forward meta store the global meta information of the forward
         self.forward_meta: ForwardMeta = None
@@ -876,7 +875,6 @@ class GPUModelRunner(ModelRunnerBase):
         )
         self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-        self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
         self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
         self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
@@ -890,6 +888,7 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize forward meta data
         self.initialize_forward_meta()
+        self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

         # Get sampling metadata
         self.sampling_metadata = SamplingMetadata(
@@ -992,7 +991,6 @@ class GPUModelRunner(ModelRunnerBase):
         Initialize kv cache
         """
         cache_kvs = {}
-        max_block_num = self.num_gpu_blocks

         # Get kv cache dtype
         cache_type = self.parallel_config.dtype
@@ -1008,7 +1006,7 @@ class GPUModelRunner(ModelRunnerBase):
         # Get kv cache shape
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
-            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
         if kv_cache_quant_type == "block_wise_fp8":
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
@@ -1055,7 +1053,7 @@ class GPUModelRunner(ModelRunnerBase):
                 del value
         paddle.device.cuda.empty_cache()

-    def initialize_attn_backend(self) -> None:
+    def _initialize_attn_backend(self) -> None:
         """
         Initialize attention backends
         """
@@ -1099,6 +1097,7 @@ class GPUModelRunner(ModelRunnerBase):
         batch_size: paddle.Tensor,
         expected_decode_len: int = 1,
         in_capturing: bool = False,
+        accept_all_drafts: bool = False,
     ) -> paddle.Tensor:
         """
         Use dummy inputs to run before formal execution.
@@ -1106,6 +1105,7 @@ class GPUModelRunner(ModelRunnerBase):
             num_tokens:
             expected_decode_len: Expected number of tokens generated
             in_capturing: Is cuda graph in capturing state
+            accept_all_drafts: Target model will accept all draft tokens
         """
         self._dummy_prefill_inputs(
             num_tokens=num_tokens,
@@ -1134,12 +1134,16 @@ class GPUModelRunner(ModelRunnerBase):
                     self.share_inputs["image_features"],
                     self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]
                 hidden_states = model_output
             else:
                 model_output = self.model(
                     ids_remove_padding=self.share_inputs["ids_remove_padding"],
                     forward_meta=self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]

                 hidden_states = rebuild_padding(
                     model_output,
@@ -1179,6 +1183,7 @@ class GPUModelRunner(ModelRunnerBase):
                     self.sampling_metadata,
                     self.parallel_config.max_model_len,
                     self.share_inputs,
+                    accept_all_drafts,
                 )
                 sampler_output = None
                 if self.parallel_config.tensor_parallel_size > 1:
@@ -1339,14 +1344,55 @@ class GPUModelRunner(ModelRunnerBase):
         time_before_capture = time.perf_counter()
         expected_decode_len = 1
         capture_sizes = self.cudagraph_capture_sizes.copy()
-        for batch_size in sorted(capture_sizes, reverse=True):
-            self._dummy_run(
-                num_tokens=self.parallel_config.max_num_batched_tokens,
-                batch_size=batch_size,
-                in_capturing=True,
-                expected_decode_len=expected_decode_len,
-            )
-            logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
+
+        if self.speculative_decoding and self.speculative_method == "mtp":
+            # Capture Target Model without bsz 1
+            for batch_size in sorted(capture_sizes, reverse=True):
+                if batch_size == 1:
+                    logger.info("Skip token_num = 1, when capture target model for mtp")
+                else:
+                    assert batch_size % 2 == 0
+                    self._dummy_run(
+                        num_tokens=self.parallel_config.max_num_batched_tokens,
+                        batch_size=int(batch_size / 2),
+                        in_capturing=True,
+                        expected_decode_len=1,
+                    )
+                    logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
+
+            # Capture Draft Model without bsz 1
+            # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
+            for batch_size in sorted(capture_sizes, reverse=True):
+                if batch_size == 1:
+                    logger.info("Skip token_num = 1, when capture Draft model for mtp")
+                else:
+                    assert batch_size % 2 == 0
+                    self._dummy_run(
+                        num_tokens=self.parallel_config.max_num_batched_tokens,
+                        batch_size=int(batch_size / 2),
+                        in_capturing=True,
+                        expected_decode_len=3,
+                        accept_all_drafts=True,
+                    )
+                    logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
+
+            # Capture Draft Model with bsz 1
+            if 1 in capture_sizes:
+                self._dummy_run(
+                    num_tokens=self.parallel_config.max_num_batched_tokens,
+                    batch_size=int(1),
+                    in_capturing=True,
+                    expected_decode_len=3,
+                    accept_all_drafts=False,
+                )
+                logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
+        else:
+            for batch_size in sorted(capture_sizes, reverse=True):
+                self._dummy_run(
+                    num_tokens=self.parallel_config.max_num_batched_tokens,
+                    batch_size=batch_size,
+                    in_capturing=True,
+                    expected_decode_len=expected_decode_len,
+                )
+                logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")

         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
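
For MTP, capture therefore runs in three passes over the capture list: the target model at half of each even size, the draft model at the same halved sizes with accept_all_drafts=True and expected_decode_len=3, and the draft model once more at batch size 1. A worked example of the resulting dummy-run shapes (hypothetical capture list; tuples are (batch_size, expected_decode_len, accept_all_drafts)):

capture_sizes = [8, 4, 2, 1]   # illustrative only, not a real default

target_runs = [s // 2 for s in sorted(capture_sizes, reverse=True) if s != 1]
draft_runs = [(s // 2, 3, True) for s in sorted(capture_sizes, reverse=True) if s != 1]
if 1 in capture_sizes:
    draft_runs.append((1, 3, False))   # extra draft-model capture at batch size 1

print(target_runs)  # [4, 2, 1]
print(draft_runs)   # [(4, 3, True), (2, 3, True), (1, 3, True), (1, 3, False)]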
@@ -1427,12 +1473,16 @@ class GPUModelRunner(ModelRunnerBase):
                 self.share_inputs["image_features"],
                 self.forward_meta,
             )
+            if self.use_cudagraph:
+                model_output = model_output[: self.real_token_num]
             hidden_states = model_output
         else:
             model_output = self.model(
                 ids_remove_padding=self.share_inputs["ids_remove_padding"],
                 forward_meta=self.forward_meta,
             )
+            if self.use_cudagraph:
+                model_output = model_output[: self.real_token_num]

             hidden_states = rebuild_padding(
                 model_output,
                 self.share_inputs["cu_seqlens_q"],
@@ -1628,20 +1678,22 @@ class GPUModelRunner(ModelRunnerBase):
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
         self.initialize_kv_cache(profile=True)
+        if self.speculative_method in ["mtp"]:
+            self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)

         # 1. Profile with multimodal encoder & encoder cache

         # 2. Dummy run
         self._dummy_run(
             num_tokens=self.parallel_config.max_num_batched_tokens,
-            batch_size=min(self.parallel_config.max_num_seqs, 3),
+            batch_size=self.parallel_config.max_num_seqs,
         )

         # 3. gc
         self.clear_cache()

         if self.speculative_method in ["mtp"]:
-            self.proposer.clear_dummy_input()
+            self.proposer.clear_mtp_cache()

     def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         """
@@ -1671,7 +1723,7 @@ class GPUModelRunner(ModelRunnerBase):
         )

         if self.speculative_method in ["mtp"]:
-            self.proposer.update_block_num(num_gpu_blocks)
+            self.proposer.update_mtp_block_num(num_gpu_blocks)

     def cal_theortical_kvcache(self):
         """
@@ -1756,6 +1808,7 @@ class GPUModelRunner(ModelRunnerBase):
         # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
         if self.use_cudagraph:
             self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
+            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
         return

     def _init_image_preprocess(self) -> None:

View File

@@ -207,7 +207,7 @@ class GpuWorker(WorkerBase):
         """
         Perform the warm-up and the graph optimization
         """
-        if self.model_runner.graph_opt_level >= 1:
+        if self.fd_config.graph_opt_config.graph_opt_level >= 1:
             self.model_runner.sot_warmup()
         # Triger cuda grpah capture
         self.model_runner.capture_model()