mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-27 18:41:02 +08:00
[Executor] CUDAGraph support Speculate Decode (#4258)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
* [Executor] CUDAGraph support Speculate Decode
* fix problem
* solve problem
* fix
* fast compile
* CUDAGraph + mtp support eb5(only target model)
* Revert "fast compile"
This reverts commit 3cfe8373ed.
* fix precommit
* solve comment
* fix comment about #pragma unroll
---------
Co-authored-by: gongshaotian <gstain5555@outlook.com>
Co-authored-by: gongshaotian <gstian5555@outlook.com>
@@ -2410,6 +2410,9 @@ __global__ void merge_multi_chunks_v2_kernel(
   __shared__ float md_smem[bdy * 2];
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
     const uint32_t bid = batch_id_per_token[qid];
+    if(bid == -1){
+      continue;
+    }
     const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     const int seq_len_q = seq_lens_q[bid];
     if (seq_len_q == 0) continue;

@@ -2427,7 +2430,7 @@ __global__ void merge_multi_chunks_v2_kernel(
       seq_len_kv += seq_len_q;
     }
     const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
-    if (num_chunks_this_seq <= 1) {
+    if (num_chunks_this_seq <= 1 || !ENABLE_PREFILL) {
       continue;
     }

@@ -72,6 +72,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
     const int token_id = linear_index / hidden_size;
     const int ori_bi = batch_id_per_token[token_id];
+    if (ori_bi == -1) continue;
     if (seq_lens_decoder[ori_bi] == 0) continue;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v

@@ -84,15 +85,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      printf(
-          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
-          "%d %d %d %d\n",
-          block_idx,
-          write_seq_id,
-          ori_bi,
-          seq_lens_decoder[ori_bi],
-          token_id,
-          cu_seqlens_q[ori_bi]);
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
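
The hunks above change the failure mode for slots that have no real request behind them: when a batch is padded up to a CUDAGraph capture size, the padded token slots map to `batch_id_per_token == -1` and their block-table lookup can return a negative index, so the kernels now skip or return early instead of printing a fatal error. A minimal illustrative sketch of that padding idea, with hypothetical names (`pad_to_capture_size`, the `-1` sentinel layout) used only for illustration:

# Illustrative sketch only: shows why padded slots carry -1 sentinels that the
# CUDA kernels above must skip instead of treating as a fatal error.
from typing import List, Tuple


def pad_to_capture_size(batch_id_per_token: List[int], capture_sizes: List[int]) -> Tuple[List[int], int]:
    """Pad the per-token batch-id mapping up to the nearest captured shape.

    Padded slots get -1, which is the marker the kernels check
    (e.g. `if (bid == -1) continue;` above).
    """
    real_shape = len(batch_id_per_token)
    padded_shape = next(s for s in sorted(capture_sizes) if s >= real_shape)
    padded = batch_id_per_token + [-1] * (padded_shape - real_shape)
    return padded, padded_shape


if __name__ == "__main__":
    # Three real tokens, captured shapes {2, 4, 8}: the graph replays at shape 4
    # and the fourth slot is a dummy that the kernels early-return on.
    print(pad_to_capture_size([0, 0, 1], [2, 4, 8]))  # ([0, 0, 1, -1], 4)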
@@ -149,13 +142,13 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     if (hi < num_heads) {
       Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
       }
     } else {
       Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
       }

@@ -390,15 +383,7 @@ __global__ void append_speculate_cache_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      printf(
-          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
-          "%d %d %d %d\n",
-          block_idx,
-          write_seq_id,
-          ori_bi,
-          seq_lens_decoder[ori_bi],
-          token_id,
-          cu_seqlens_q[ori_bi]);
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;

@@ -525,15 +510,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      printf(
-          "Fatal Error!!!, block idx %d when write_seq_id is %d\n some key var "
-          "%d %d %d %d\n",
-          block_idx,
-          write_seq_id,
-          ori_bi,
-          seq_lens_decoder[ori_bi],
-          token_id,
-          cu_seqlens_q[ori_bi]);
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;

@@ -676,7 +676,7 @@ void SpeculateVerify(
     const paddle::Tensor &output_cum_offsets,
     const paddle::Tensor &actual_candidate_len,
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
-    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode);
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts);

 void SpeculateUpdate(const paddle::Tensor &seq_lens_encoder,
                      const paddle::Tensor &seq_lens_decoder,

@@ -139,6 +139,7 @@ std::vector<paddle::DataType> SpeculateGetPaddingOffsetInferDtype(
 PD_BUILD_STATIC_OP(speculate_get_padding_offset)
     .Inputs({"input_ids",
              "draft_tokens",
+             "cum_offsets",
              "token_num",
              "seq_len",
              "seq_lens_encoder"})

@@ -79,7 +79,7 @@ __global__ void speculate_verify(
     const int *output_cum_offsets, const int *actual_candidate_len,
     const int real_bsz, const int max_draft_tokens, const int end_length,
     const int max_seq_len, const int max_candidate_len, const int verify_window,
-    const bool prefill_one_step_stop, const bool benchmark_mode) {
+    const bool prefill_one_step_stop, const bool benchmark_mode, const bool accept_all_drafts) {
   const int bid = threadIdx.x;
   // verify and set stop flags
   int accept_num_now = 1;

@@ -107,6 +107,24 @@ __global__ void speculate_verify(
       if (seq_lens_encoder[bid] != 0) {
         break;
       }
+      if (accept_all_drafts) {
+        // accept all draft tokens
+        step_idx[bid]++;
+        auto accept_token = draft_tokens_now[i + 1];
+        accept_tokens[bid * max_draft_tokens + i] = accept_token;
+
+        if (is_in_end(accept_token, end_tokens, end_length) ||
+            step_idx[bid] >= max_dec_len[bid]) {
+          stop_flags[bid] = true;
+          stop_flag_now_int = 1;
+          if (step_idx[bid] >= max_dec_len[bid])
+            accept_tokens[bid * max_draft_tokens + i] = end_tokens[0];
+          break;
+        } else {
+          accept_num_now++;
+        }
+        continue;
+      }
       if (USE_TOPK) {
         if (verify_tokens_now[i * max_candidate_len] ==
             draft_tokens_now[i + 1]) {
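
The `accept_all_drafts` branch above lets the verify kernel accept every draft token unconditionally, which the capture path uses so the draft (MTP) model is warmed up along its longest decode path. A pure-Python, host-side sketch of that branch is given below for illustration; the function name and signature are hypothetical, the real logic runs inside the CUDA kernel:

# Host-side sketch of the accept_all_drafts branch added to speculate_verify above;
# an illustration only, not the CUDA kernel itself.
def accept_drafts_unconditionally(draft_tokens, end_tokens, step_idx, max_dec_len):
    """Accept every draft token, stopping only on an end token or max length."""
    accept_tokens, stop_flag = [], False
    for tok in draft_tokens:
        step_idx += 1
        accept_tokens.append(tok)
        if tok in end_tokens or step_idx >= max_dec_len:
            if step_idx >= max_dec_len:
                accept_tokens[-1] = end_tokens[0]  # clamp to an end token, as the kernel does
            stop_flag = True
            break
    return accept_tokens, step_idx, stop_flag


# Example: three drafts, no end token hit, so all three are accepted.
assert accept_drafts_unconditionally([5, 7, 9], end_tokens=[2], step_idx=0, max_dec_len=100)[0] == [5, 7, 9]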
@@ -255,7 +273,7 @@ void SpeculateVerify(
     const paddle::Tensor &output_cum_offsets,
     const paddle::Tensor &actual_candidate_len,
     const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
-    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
+    int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode, bool accept_all_drafts) {
   // printf("Enter speculate update\n");
   auto bsz = accept_tokens.shape()[0];
   int real_bsz = seq_lens_this_time.shape()[0];

@@ -298,7 +316,7 @@ void SpeculateVerify(
              is_block_step.data<bool>(), output_cum_offsets.data<int>(),
              actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
              end_length, max_seq_len, max_candidate_len, verify_window,
-             prefill_one_step_stop, benchmark_mode);
+             prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    } else {
      speculate_verify<false, true>
          <<<1, BlockSize, 0, accept_tokens.stream()>>>(

@@ -314,7 +332,7 @@ void SpeculateVerify(
              end_tokens.data<int64_t>(), is_block_step.data<bool>(),
              output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
              real_bsz, max_draft_tokens, end_length, max_seq_len,
-             max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
+             max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    }
  } else {
    if (enable_topp) {

@@ -332,7 +350,7 @@ void SpeculateVerify(
              end_tokens.data<int64_t>(), is_block_step.data<bool>(),
              output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
              real_bsz, max_draft_tokens, end_length, max_seq_len,
-             max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
+             max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    } else {
      speculate_verify<false, false>
          <<<1, BlockSize, 0, accept_tokens.stream()>>>(

@@ -348,7 +366,7 @@ void SpeculateVerify(
              end_tokens.data<int64_t>(), is_block_step.data<bool>(),
              output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
              real_bsz, max_draft_tokens, end_length, max_seq_len,
-             max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
+             max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode, accept_all_drafts);
    }
  }

@@ -363,7 +381,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
              "actual_candidate_len", "actual_draft_token_nums", "topp"})
     .Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
               "stop_flags_out"})
-    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
+    .Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool","accept_all_drafts: bool"})
     .SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
                     {"accept_num", "accept_num_out"},
                     {"step_idx", "step_idx_out"},

@@ -1150,7 +1150,14 @@ class FDConfig:
         # Initialize cuda graph capture list
         if self.graph_opt_config.cudagraph_capture_sizes is None:
             self.graph_opt_config._set_cudagraph_sizes(max_num_seqs=self.parallel_config.max_num_seqs)
-        self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)
+
+        if self.speculative_config is not None and self.speculative_config.method == "mtp":
+            max_shape = self.parallel_config.max_num_seqs * (self.speculative_config.num_speculative_tokens + 1)
+            if max_shape % 2 == 1:
+                max_shape = max_shape + 1
+            self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=min(512, max_shape))
+        else:
+            self.graph_opt_config.init_with_cudagrpah_size(max_num_seqs=self.parallel_config.max_num_seqs)

         # TODO(wangmingkai02): change graph_opt_level=2 when using static mode with cinn
         if self.graph_opt_config.graph_opt_level == 2:
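
The FDConfig hunk above sizes the CUDAGraph capture list for MTP by the worst-case token count per step: each sequence contributes one target token plus its speculative tokens, the result is rounded up to an even number, and the bound is capped at 512. A small sketch that reproduces that arithmetic (the helper name is illustrative, not a FastDeploy API):

# Sketch of the capture-size bound computed in the FDConfig hunk above.
def mtp_cudagraph_max_shape(max_num_seqs: int, num_speculative_tokens: int, cap: int = 512) -> int:
    # Each sequence contributes one target token plus its speculative tokens per step.
    max_shape = max_num_seqs * (num_speculative_tokens + 1)
    if max_shape % 2 == 1:  # keep the captured shape even
        max_shape += 1
    return min(cap, max_shape)


# Example: 128 sequences with 1 speculative token -> capture list sized for 256 tokens.
assert mtp_cudagraph_max_shape(128, 1) == 256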
@@ -133,6 +133,7 @@ class ForwardMeta:
                     "shape": obj.shape,
                     "dtype": str(obj.dtype),
                     "place": str(obj.place),
+                    "content": obj if obj.numel() < 70 else "Too big to show",
                 }
                 return tensor_info
             elif isinstance(obj, (list, tuple)):

@@ -14,6 +14,7 @@
 # limitations under the License.
 """

+import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from typing import Callable, Dict, List, Optional

@@ -111,7 +112,7 @@ class CudaGraphPiecewiseBackend:
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
                 logger.debug(
-                    f"[CUDA GRAPH] Warm up for batch size {entry.real_shape}, "
+                    f"[CUDA GRAPH] [ID:{id(self)}] Warm up for batch size {entry.real_shape}, "
                     f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )

@@ -138,15 +139,17 @@ class CudaGraphPiecewiseBackend:
         real_shape = ids_remove_padding.shape[0]
         padding_real_shape = self.real_shape_to_captured_size[real_shape]
         logger.debug(
-            f"[CUDA GRAPH] The actual real shape obtained by CUDAGraph is :{real_shape}, "
-            f"The padded shape is :{padding_real_shape}"
+            f"[CUDA GRAPH] [ID:{id(self)}] The actual real shape obtained by CUDAGraph is :{real_shape}, "
+            f"The padded shape is :{padding_real_shape}, If Padding :{real_shape != padding_real_shape}"
         )

         entry = self.concrete_size_entries.get(padding_real_shape)
         assert entry is not None, f"real shape:{padding_real_shape} is not in cuda graph capture list."
         if entry.runnable is None:
             entry.runnable = self.runnable
-            logger.debug(f"[CUDA GRAPH] New entry lazy initialize with real shape {padding_real_shape}")
+            logger.debug(
+                f"[CUDA GRAPH] [ID:{id(self)}] New entry lazy initialize with real shape {padding_real_shape}"
+            )

         if not entry.use_cudagraph:
             return entry.runnable(**kwargs)

@@ -161,7 +164,7 @@ class CudaGraphPiecewiseBackend:
                 entry.num_finished_warmup += 1
                 entry.runnable(**kwargs)
                 logger.debug(
-                    f"[CUDA GRAPH] Warm up for real shape {padding_real_shape}, "
+                    f"[CUDA GRAPH] [ID:{id(self)}] Warm up for real shape {padding_real_shape}, "
                     f"finished ({n + 1}/{entry.num_finished_warmup}) times"
                 )

@@ -196,11 +199,11 @@ class CudaGraphPiecewiseBackend:

             # For CUDAGraph debug
             # self._save_cudagrpah_dot_files(entry)
-            logger.debug(f"[CUDA GRAPH] CUDAGraph captured for real shape {padding_real_shape}")
+            logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph captured for real shape {padding_real_shape}")

         # Replay
         entry.cuda_graph.replay()
-        logger.debug(f"[CUDA GRAPH] CUDAGraph replayed for real shape {padding_real_shape}")
+        logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph replayed for real shape {padding_real_shape}")
         if len(entry.output_buffers) == 1:
             return entry.output_buffers[0]
         return entry.output_buffers

@@ -214,16 +217,17 @@ class CudaGraphPiecewiseBackend:
             self.concrete_size_entries[shape] = ConcreteSizeEntry(real_shape=shape)

         logger.info(
-            f"[CUDA GRAPH] CUDAGraph capture list {self.cudagraph_capture_sizes}, " "Created all real shape entry."
+            f"[CUDA GRAPH] [ID:{id(self)}] CUDAGraph capture list {self.cudagraph_capture_sizes}, "
+            "Created all real shape entry."
         )

     def clear_graph(self):
         """ """
         # Clear graphs
-        for id, entry in self.concrete_size_entries.items():
+        for _id, entry in self.concrete_size_entries.items():
             if entry.cuda_graph:
                 del entry.cuda_graph
-                logger.debug(f"[CUDA GRAPH] The CUDAGraph with shape {id} has been cleared.")
+                logger.debug(f"[CUDA GRAPH] [ID:{id(self)}] The CUDAGraph with shape {_id} has been cleared.")

         del self.concrete_size_entries
         paddle.device.cuda.empty_cache()

@@ -236,6 +240,6 @@ class CudaGraphPiecewiseBackend:
         log_dir = envs.FD_LOG_DIR
         if entry.cuda_graph:
             entry.cuda_graph.print_to_dot_files(
-                f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}",
+                f"./{log_dir}/GraphDotFiles/backend{id(self)}_shape{entry.real_shape}_time{time.perf_counter()}",
                 1 << 0,
             )

@@ -115,7 +115,7 @@ class GraphOptBackend:
         self.runnable = runnable
         self.fd_config = fd_config

-        self.max_captre_batch = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
+        self.max_captre_size = fd_config.graph_opt_config.cudagraph_capture_sizes[0]
         if self.fd_config.graph_opt_config.graph_opt_level > 0:
             # 1. Prepare cuda grpah input buffers (contain output of subgraphs)

@@ -138,9 +138,9 @@ class GraphOptBackend:
             )

         assert kwargs["forward_meta"].ids_remove_padding is not None
-        batch_size = kwargs["forward_meta"].ids_remove_padding.shape[0]
+        real_shape = kwargs["forward_meta"].ids_remove_padding.shape[0]

-        if (not kwargs["forward_meta"].step_use_cudagraph) or (batch_size > self.max_captre_batch):
+        if (not kwargs["forward_meta"].step_use_cudagraph) or (real_shape > self.max_captre_size):
             return self.runnable(**kwargs)
         else:
             return self.cudagraph_piecewise_backend.__call__(**kwargs)
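
After this change, GraphOptBackend keys its dispatch on the flattened token count (`real_shape`) rather than a per-request batch size: if the step is not flagged for CUDAGraph, or the shape exceeds the largest captured size, it falls back to the eager runnable; otherwise it hands the call to the piecewise CUDAGraph backend. A tiny sketch of that decision rule, with illustrative names:

# Sketch of the dispatch rule in GraphOptBackend.__call__ after this change.
def dispatch(step_use_cudagraph: bool, real_shape: int, max_capture_size: int) -> str:
    """Return which path a forward step takes; names are illustrative."""
    if (not step_use_cudagraph) or (real_shape > max_capture_size):
        return "eager"          # run the original runnable
    return "cudagraph_replay"   # hand off to CudaGraphPiecewiseBackend


assert dispatch(True, 48, 512) == "cudagraph_replay"
assert dispatch(True, 1024, 512) == "eager"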
@@ -395,6 +395,7 @@ class SpeculativeSampler(nn.Layer):
         sampling_metadata: SamplingMetadata,
         max_model_len: int,
         share_inputs: List[paddle.Tensor],
+        accept_all_drafts: bool = False,
     ) -> paddle.Tensor:
         """ """

@@ -451,6 +452,7 @@ class SpeculativeSampler(nn.Layer):
             self.speculative_verify_window,
             True,  # enable_topp
             self.speculative_benchmark_mode,
+            accept_all_drafts,
         )

         return None

@@ -27,6 +27,9 @@ from paddleformers.utils.log import logger

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.forward_meta import ForwardMeta
+from fastdeploy.model_executor.graph_optimization.decorator import (
+    support_graph_optimization,
+)
 from fastdeploy.model_executor.layers.mtp_linear import ParallelEHProjection
 from fastdeploy.model_executor.layers.normalization import RMSNorm
 from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer

@@ -229,6 +232,7 @@ class Ernie4_5_MTPPretrainedModel(PretrainedModel):
         return mappings


+@support_graph_optimization
 class Ernie4_5_MTPModel(nn.Layer):
     """
     Ernie4_5_MTPModel

@@ -435,6 +439,10 @@ class Ernie4_5_MTPForCausalLM(ModelForCasualLM):
         """
         forward
         """
-        hidden_states = self.ernie(ids_remove_padding, previous_hidden_states, forward_meta)
+        hidden_states = self.ernie(
+            ids_remove_padding=ids_remove_padding,
+            previous_hidden_states=previous_hidden_states,
+            forward_meta=forward_meta,
+        )

         return hidden_states

@@ -33,31 +33,33 @@ class Proposer(ABC):
     the speculative decoding framework
     """

-    def __init__(self, cfg: FDConfig):
+    def __init__(self, fd_config: FDConfig):
         """
         Init Speculative proposer
         """
-        cfg.parallel_config.tp_group = None
-        cfg.parallel_config.ep_group = None
-        self.cfg = deepcopy(cfg)
-        cfg.parallel_config.tp_group = dist.get_group(
-            cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+        fd_config.parallel_config.tp_group = None
+        fd_config.parallel_config.ep_group = None
+        self.fd_config = deepcopy(fd_config)
+        fd_config.parallel_config.tp_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
-        cfg.parallel_config.ep_group = dist.get_group(
-            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.cfg.parallel_config.tp_group = dist.get_group(
-            cfg.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
+        self.fd_config.parallel_config.tp_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_rank + envs.FD_TP_GROUP_GID_OFFSET
         )
-        self.cfg.parallel_config.ep_group = dist.get_group(
-            cfg.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
+        self.fd_config.parallel_config.ep_group = dist.get_group(
+            fd_config.parallel_config.data_parallel_size + envs.FD_TP_GROUP_GID_OFFSET
         )

-        self.parallel_config = self.cfg.parallel_config
-        self.model_config = self.cfg.model_config
-        self.speculative_config = self.cfg.speculative_config
-        self.cache_config = self.cfg.cache_config
-        self.quant_config = self.cfg.quant_config
+        self.parallel_config = self.fd_config.parallel_config
+        self.model_config = self.fd_config.model_config
+        self.speculative_config = self.fd_config.speculative_config
+        self.cache_config = self.fd_config.cache_config
+        self.quant_config = self.fd_config.quant_config
+        self.graph_opt_config = self.fd_config.graph_opt_config
+        self.scheduler_config = self.fd_config.scheduler_config

         self.max_num_seqs = self.parallel_config.max_num_seqs
         self.max_model_len = self.parallel_config.max_model_len
@@ -22,6 +22,7 @@ import paddle
 from paddleformers.utils.log import logger

 from fastdeploy import envs
+from fastdeploy.config import FDConfig
 from fastdeploy.engine.request import Request, RequestType
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 from fastdeploy.model_executor.layers.attention import get_attention_backend

@@ -31,6 +32,8 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
 from fastdeploy.model_executor.layers.rotary_embedding import get_rope
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import MTPSampler
+from fastdeploy.model_executor.model_loader import get_model_loader
+from fastdeploy.model_executor.models import ModelForCasualLM
 from fastdeploy.model_executor.ops.gpu import (
     draft_model_postprocess,
     draft_model_preprocess,

@@ -52,12 +55,19 @@ class MTPProposer(Proposer):
     Proposer for Multi-Token-Prediction(MTP)
     """

-    def __init__(self, cfg, main_model, local_rank, device_id, target_model_inputs):
-        super().__init__(cfg)
+    def __init__(
+        self,
+        fd_config: FDConfig,
+        main_model: ModelForCasualLM,
+        local_rank: int,
+        device_id: int,  # physical device id
+        target_model_inputs,  # main model share inputs
+    ):
+        super().__init__(fd_config)
         self.num_main_model_layers = self.model_config.num_hidden_layers
         self.local_rank = local_rank
         self.device_id = device_id
-        self._update_cfg(main_model)
+        self._update_mtp_config(main_model)
         self._load_model()
         self.target_model_inputs = target_model_inputs
         self.mtp_strategy = self.speculative_config.mtp_strategy

@@ -65,16 +75,22 @@ class MTPProposer(Proposer):

         # [mixed, prefill, decoder]
         self.role = "mixed"
-        self.sampler = MTPSampler(cfg)
+        self.sampler = MTPSampler(fd_config)
         self._init_model_inputs()

+        # CUDA Graph
+        self.use_cudagraph = False  # self.graph_opt_config.use_cudagraph
+        self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
+        self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes
+
         self.attn_backends: list[AttentionBackend] = []
         self._initialize_attn_backend()

-    def _update_cfg(self, main_model):
+    def _update_mtp_config(self, main_model):
         """
         Update config for MTP from global config
         """
+        self.forward_meta: ForwardMeta = None
         self.model_config.architectures[0] = self.model_config.architectures[0].replace("Moe", "MTP")
         self.speculative_config.sharing_model = main_model
         self.model_config.num_hidden_layers = 1

@@ -89,21 +105,18 @@ class MTPProposer(Proposer):
         """
         Load MTP Layer
         """
-        from fastdeploy.model_executor.model_loader import get_model_loader

-        model_loader = get_model_loader(load_config=self.cfg.load_config)
-        self.model = model_loader.load_model(fd_config=self.cfg)
+        model_loader = get_model_loader(load_config=self.fd_config.load_config)
+        self.model = model_loader.load_model(fd_config=self.fd_config)

     def dummy_prefill_inputs(self, num_tokens: int, batch_size: int, expected_decode_len: int):
         """Set dummy prefill inputs to model_inputs"""
         max_dec_len = expected_decode_len + 1
-        self.num_gpu_blocks = self.parallel_config.total_block_num
-        self.initialize_kv_cache()
-        full_length = min(
+        input_length = min(
             num_tokens // batch_size,
             self.parallel_config.max_model_len - max_dec_len,
         )
-        input_length = int(full_length * self.cache_config.kv_cache_ratio)
         block_num = (
             input_length + self.cache_config.block_size - 1
         ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num
@@ -125,13 +138,15 @@ class MTPProposer(Proposer):
         )
         self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

-    def initialize_kv_cache(self):
+    def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False):
         """
         Initialize kv cache
         """
-        # prompt cache
+        self.num_gpu_blocks = int(main_model_num_blocks * self.speculative_config.num_gpu_block_expand_ratio)
         self.cache_kvs = {}

+        # Get kv cache dtype
         cache_type = self.parallel_config.dtype

         kv_cache_quant_type = None

@@ -151,9 +166,7 @@ class MTPProposer(Proposer):
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

-        if not self.parallel_config.do_profile and (
-            self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"
-        ):
+        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
             cache_kvs_list = []
             for i in range(
                 self.num_main_model_layers,

@@ -230,7 +243,7 @@ class MTPProposer(Proposer):
             # Get the attention backend
             attn_cls = get_attention_backend()
             attn_backend = attn_cls(
-                self.cfg,
+                self.fd_config,
                 kv_num_heads=self.model_config.kv_num_heads,
                 num_heads=num_heads,
                 head_dim=head_dim,

@@ -243,7 +256,7 @@ class MTPProposer(Proposer):
             )
             self.attn_backends.append(attn_backend)

-    def clear_dummy_input(self):
+    def clear_mtp_cache(self):
         """
         Clear allocated cacheKV
         """

@@ -251,15 +264,14 @@ class MTPProposer(Proposer):
         if self.forward_meta is not None:
             del self.forward_meta.caches

-    def update_block_num(self, num_gpu_blocks) -> None:
+    def update_mtp_block_num(self, num_gpu_blocks) -> None:
         """
-        Update block num by theoretical calculation
+        Update MTP block num by theoretical calculation
         """

+        # Reset block table and kv cache with global block num
         self.main_model_num_gpu_blocks = num_gpu_blocks
-        self.num_gpu_blocks = int(num_gpu_blocks * self.speculative_config.num_gpu_block_expand_ratio)
-        if not (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
-            self.initialize_kv_cache()
+        self.initialize_kv_cache(main_model_num_blocks=self.main_model_num_gpu_blocks)

         # Reset free list
         free_list = list(

@@ -276,7 +288,6 @@ class MTPProposer(Proposer):
                 "free_list_len": paddle.full([1], self.free_list_len, dtype="int32"),
             }
         )
-        self.parallel_config.do_profile = False

     def _init_model_inputs(self):
         """

@@ -300,6 +311,8 @@ class MTPProposer(Proposer):
         self.model_inputs["stop_nums"] = paddle.clone(self.target_model_inputs["stop_nums"])
         self.model_inputs["not_need_stop"] = paddle.to_tensor([False], dtype="bool", place="cpu")
         self.model_inputs["pre_ids"] = paddle.clone(self.target_model_inputs["pre_ids"])
+        self.model_inputs["output_cum_offsets"] = paddle.clone(self.target_model_inputs["output_cum_offsets"])
+        self.model_inputs["output_padding_offset"] = paddle.clone(self.target_model_inputs["output_padding_offset"])
         self.model_inputs["ids_remove_padding"] = paddle.clone(self.target_model_inputs["ids_remove_padding"])
         self.model_inputs["batch_id_per_token"] = paddle.clone(self.target_model_inputs["batch_id_per_token"])
         self.model_inputs["cu_seqlens_q"] = paddle.clone(self.target_model_inputs["cu_seqlens_q"])
@@ -308,6 +321,9 @@ class MTPProposer(Proposer):
         self.model_inputs["decoder_tile_ids_per_batch"] = paddle.clone(
             self.target_model_inputs["decoder_tile_ids_per_batch"]
         )
+        self.model_inputs["target_hidden_states"] = paddle.full(
+            [self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
+        )

         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
         self.model_inputs["rope_emb"] = get_rope(

@@ -443,9 +459,6 @@ class MTPProposer(Proposer):
         """
         Process inputs for prefill tasks and insert it to model_inputs buffer
         """
-        # NOTE: Lazy initialize kv cache
-        if "caches" not in self.model_inputs:
-            self.initialize_kv_cache()

         # TODO:Init role in initialize process
         if req_dicts[-1].disaggregate_info is not None:

@@ -526,7 +539,7 @@ class MTPProposer(Proposer):
                     request.get("block_tables"), dtype="int32"
                 )
         self.model_inputs["not_need_stop"][0] = True
-        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests]
+        self.model_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer

     def _initialize_forward_meta(self):
         """

@@ -556,6 +569,33 @@ class MTPProposer(Proposer):
         for attn_backend in self.attn_backends:
             attn_backend.init_attention_metadata(self.forward_meta)

+        # Update Batch type for cuda graph
+        only_decode_batch = True
+        prefill_exists = None
+
+        # Mix ep in single node
+        if self.fd_config.parallel_config.use_ep and self.fd_config.parallel_config.splitwise_role == "mixed":
+            only_decode_batch_list = []
+            prefill_exists = self.exist_prefill()
+            paddle.distributed.all_gather_object(only_decode_batch_list, not prefill_exists)
+            only_decode_batch = all(only_decode_batch_list)
+            self.fd_config.model_config.moe_phase.phase = "decode" if only_decode_batch else "prefill"
+
+        self.forward_meta.step_use_cudagraph = (
+            self.use_cudagraph
+            and only_decode_batch
+            and not (prefill_exists if prefill_exists is not None else self.exist_prefill())
+        )
+
+    def exist_prefill(self):
+        """
+        check whether prefill stage exist
+        """
+        if int(paddle.max(self.model_inputs["seq_lens_encoder"])) != 0:
+            return 1
+        else:
+            return 0
+
     def _prepare_inputs(self, full_hidden_states):
         """
         Prepare MTP inputs

@@ -599,10 +639,8 @@ class MTPProposer(Proposer):
             self.target_model_inputs["seq_lens_encoder"],
             self.num_model_steps,
         )
-        if isinstance(target_hidden_states, list):
-            target_hidden_states = target_hidden_states[0]

-        return target_hidden_states
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

     def _post_process(self, sampled_token_ids):
         """

@@ -633,7 +671,7 @@ class MTPProposer(Proposer):
                 self.parallel_config.use_ep,
             )

-    def _propose(self, target_hidden_states):
+    def _propose(self):
         """
         Main process for MTP inference
         """

@@ -663,10 +701,15 @@ class MTPProposer(Proposer):
             self.model_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
             self.model_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)
             # for speculative decoding
-            self.model_inputs["output_cum_offsets"] = output_cum_offsets
-            self.model_inputs["output_padding_offset"] = output_padding_offset
+            self.model_inputs["output_cum_offsets"].copy_(output_cum_offsets, False)
+            self.model_inputs["output_padding_offset"].copy_(output_padding_offset, False)

+            # Initialize forward meta data
             self._initialize_forward_meta()

+            # Padding inputs for cuda graph
+            self.padding_cudagraph_inputs()
+
             # Get sampling metadata
             self.sampling_metadata = SamplingMetadata(
                 temperature=self.model_inputs["temperature"],
@@ -687,9 +730,11 @@ class MTPProposer(Proposer):

             model_output = self.model(
                 ids_remove_padding=self.model_inputs["ids_remove_padding"],
-                previous_hidden_states=target_hidden_states,
+                previous_hidden_states=self.model_inputs["target_hidden_states"],
                 forward_meta=self.forward_meta,
             )
+            if self.use_cudagraph:
+                model_output = model_output[: self.real_token_num]

             hidden_states = rebuild_padding(
                 model_output,

@@ -721,7 +766,7 @@ class MTPProposer(Proposer):
             self._post_process(sampled_token_ids)

             if substep != self.num_model_steps - 1:
-                target_hidden_states = self._get_self_hidden_states(hidden_states)
+                self._get_self_hidden_states(hidden_states)
             else:
                 if hasattr(self.model, "empty_input_forward"):
                     self.model.empty_input_forward()

@@ -733,10 +778,7 @@ class MTPProposer(Proposer):
             self.model_inputs["seq_lens_this_time"],
             self.model_inputs["step_idx"],
         )
-        if isinstance(target_hidden_states, list):
-            target_hidden_states = target_hidden_states[0]
+        self.model_inputs["target_hidden_states"].copy_(target_hidden_states, False)

-        return target_hidden_states
-
     def update_task_chunk_prefill(self, task):
         """

@@ -821,8 +863,8 @@ class MTPProposer(Proposer):

     def _run_impl(self, full_hidden_states):
         """"""
-        target_hidden_states = self._prepare_inputs(full_hidden_states)
-        self._propose(target_hidden_states=target_hidden_states)
+        self._prepare_inputs(full_hidden_states)
+        self._propose()
         self._update_status()
         if self.hybrid_mode:
             self._extend_draft_token_with_ngram_match()

@@ -830,3 +872,16 @@ class MTPProposer(Proposer):
     def is_chunk_prefill_enabled(self):
         """"""
         return True
+
+    def padding_cudagraph_inputs(self) -> None:
+        """
+        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
+        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
+        """
+        # In init_attention_metadata, the decode buffer has already been cleared
+
+        # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
+        if self.use_cudagraph:
+            self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
+            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
+        return
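
When a captured graph is replayed with a padded batch, the input buffers stay at the captured shape and `real_token_num` remembers the true token count, so the padded model output can be sliced back before `rebuild_padding` (the `model_output[: self.real_token_num]` lines above). A minimal sketch of that pad-run-slice pattern, under the assumption of a generic array-in/array-out model; names are illustrative only:

# Minimal sketch of the pad-run-slice pattern used around CUDAGraph replay here.
import numpy as np


def run_padded(model, tokens: np.ndarray, captured_shape: int) -> np.ndarray:
    real_token_num = tokens.shape[0]
    padded = np.zeros((captured_shape,) + tokens.shape[1:], dtype=tokens.dtype)
    padded[:real_token_num] = tokens   # keep the replay buffer clean beyond the real rows
    out = model(padded)                # the replay always sees the captured shape
    return out[:real_token_num]        # drop the padded rows afterwards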
@@ -29,8 +29,8 @@ class NgramProposer(Proposer):
     Matching corresponding tokens in input and output as draft tokens.
     """

-    def __init__(self, cfg: FDConfig):
-        super().__init__(cfg)
+    def __init__(self, fd_config: FDConfig):
+        super().__init__(fd_config)
         self.max_ngram_size = self.speculative_config.max_ngram_size
         self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()

@@ -135,7 +135,6 @@ class GPUModelRunner(ModelRunnerBase):
         # self.kv_caches: list[paddle.Tensor] = []

         # Cuda Graph
-        self.graph_opt_level = self.graph_opt_config.graph_opt_level
         self.use_cudagraph = self.graph_opt_config.use_cudagraph
         self.cudagraph_capture_sizes = list(reversed(self.graph_opt_config.cudagraph_capture_sizes))
         self.sot_warmup_sizes = self.graph_opt_config.sot_warmup_sizes

@@ -155,7 +154,7 @@ class GPUModelRunner(ModelRunnerBase):
         # In the future, we will expand it as a list.
         self.attn_backends: list[AttentionBackend] = []
         # self.attn_metadatas: list[AttentionMetadata] = []
-        self.initialize_attn_backend()
+        self._initialize_attn_backend()

         # Forward meta store the global meta information of the forward
         self.forward_meta: ForwardMeta = None

@@ -876,7 +875,6 @@ class GPUModelRunner(ModelRunnerBase):
         )

         self.share_inputs["ids_remove_padding"].copy_(ids_remove_padding, False)
-        self.share_inputs["batch_id_per_token"].copy_(batch_id_per_token, False)
         self.share_inputs["cu_seqlens_q"].copy_(cu_seqlens_q, False)
         self.share_inputs["cu_seqlens_k"].copy_(cu_seqlens_k, False)

@@ -890,6 +888,7 @@ class GPUModelRunner(ModelRunnerBase):

         # Initialize forward meta data
         self.initialize_forward_meta()
+        self.forward_meta.batch_id_per_token.copy_(batch_id_per_token, False)

         # Get sampling metadata
         self.sampling_metadata = SamplingMetadata(

@@ -992,7 +991,6 @@ class GPUModelRunner(ModelRunnerBase):
         Initialize kv cache
         """
         cache_kvs = {}
-        max_block_num = self.num_gpu_blocks

         # Get kv cache dtype
         cache_type = self.parallel_config.dtype

@@ -1008,7 +1006,7 @@ class GPUModelRunner(ModelRunnerBase):

         # Get kv cache shape
         kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
-            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
+            max_num_blocks=self.num_gpu_blocks, kv_cache_quant_type=kv_cache_quant_type
         )
         if kv_cache_quant_type == "block_wise_fp8":
             kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]

@@ -1055,7 +1053,7 @@ class GPUModelRunner(ModelRunnerBase):
             del value
         paddle.device.cuda.empty_cache()

-    def initialize_attn_backend(self) -> None:
+    def _initialize_attn_backend(self) -> None:
         """
         Initialize attention backends
         """

@@ -1099,6 +1097,7 @@ class GPUModelRunner(ModelRunnerBase):
         batch_size: paddle.Tensor,
         expected_decode_len: int = 1,
         in_capturing: bool = False,
+        accept_all_drafts: bool = False,
     ) -> paddle.Tensor:
         """
         Use dummy inputs to run before formal execution.

@@ -1106,6 +1105,7 @@ class GPUModelRunner(ModelRunnerBase):
             num_tokens:
             expected_decode_len: Expected number of tokens generated
             in_capturing: Is cuda graph in capturing state
+            accept_all_drafts: Target model will accept all draft tokens
         """
         self._dummy_prefill_inputs(
             num_tokens=num_tokens,

@@ -1134,12 +1134,16 @@ class GPUModelRunner(ModelRunnerBase):
                     self.share_inputs["image_features"],
                     self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]
                 hidden_states = model_output
             else:
                 model_output = self.model(
                     ids_remove_padding=self.share_inputs["ids_remove_padding"],
                     forward_meta=self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]

                 hidden_states = rebuild_padding(
                     model_output,
@@ -1179,6 +1183,7 @@ class GPUModelRunner(ModelRunnerBase):
                 self.sampling_metadata,
                 self.parallel_config.max_model_len,
                 self.share_inputs,
+                accept_all_drafts,
             )
             sampler_output = None
             if self.parallel_config.tensor_parallel_size > 1:
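The `accept_all_drafts` flag added above is threaded from `_dummy_run` into the speculative sampler call. As a minimal sketch of what such a flag means conceptually — the helper below is plain Python for illustration, not FastDeploy's sampler API, and its name and signature are assumptions — forcing acceptance makes every warm-up step take the longest decode path, which is what CUDA Graph capture needs to see:

from typing import List


def verify_draft_tokens(draft: List[int], target: List[int], accept_all_drafts: bool = False) -> List[int]:
    """Hypothetical illustration: return the accepted prefix of the draft tokens.

    Normal speculative verification keeps draft tokens only while they agree with
    the target model; accept_all_drafts=True skips the comparison so the step
    always runs at its maximum speculative length.
    """
    if accept_all_drafts:
        return list(draft)
    accepted: List[int] = []
    for d, t in zip(draft, target):
        if d != t:
            break
        accepted.append(d)
    return accepted


print(verify_draft_tokens([3, 7, 9], [3, 7, 2]))                          # [3, 7]
print(verify_draft_tokens([3, 7, 9], [3, 7, 2], accept_all_drafts=True))  # [3, 7, 9]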
@@ -1339,14 +1344,55 @@ class GPUModelRunner(ModelRunnerBase):
         time_before_capture = time.perf_counter()
         expected_decode_len = 1
         capture_sizes = self.cudagraph_capture_sizes.copy()
-        for batch_size in sorted(capture_sizes, reverse=True):
-            self._dummy_run(
-                num_tokens=self.parallel_config.max_num_batched_tokens,
-                batch_size=batch_size,
-                in_capturing=True,
-                expected_decode_len=expected_decode_len,
-            )
-            logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")
+        if self.speculative_decoding and self.speculative_method == "mtp":
+            # Capture Target Model without bsz 1
+            for batch_size in sorted(capture_sizes, reverse=True):
+                if batch_size == 1:
+                    logger.info("Skip token_num = 1, when capture target model for mtp")
+                else:
+                    assert batch_size % 2 == 0
+                    self._dummy_run(
+                        num_tokens=self.parallel_config.max_num_batched_tokens,
+                        batch_size=int(batch_size / 2),
+                        in_capturing=True,
+                        expected_decode_len=1,
+                    )
+                    logger.info(f"Warm up the Target model with the num_tokens:{batch_size}, expected_decode_len:{1}")
+            # Capture Draft Model without bsz 1
+            # NOTE(liujundong): expected_decode_len = 1, will affect mtp capture in cudagraph
+            for batch_size in sorted(capture_sizes, reverse=True):
+                if batch_size == 1:
+                    logger.info("Skip token_num = 1, when capture Draft model for mtp")
+                else:
+                    assert batch_size % 2 == 0
+                    self._dummy_run(
+                        num_tokens=self.parallel_config.max_num_batched_tokens,
+                        batch_size=int(batch_size / 2),
+                        in_capturing=True,
+                        expected_decode_len=3,
+                        accept_all_drafts=True,
+                    )
+                    logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
+            # Capture Draft Model with bsz 1
+            if 1 in capture_sizes:
+                self._dummy_run(
+                    num_tokens=self.parallel_config.max_num_batched_tokens,
+                    batch_size=int(1),
+                    in_capturing=True,
+                    expected_decode_len=3,
+                    accept_all_drafts=False,
+                )
+                logger.info(f"Warm up the Draft model with the num_tokens:{batch_size}, expected_decode_len:{3}")
+        else:
+            for batch_size in sorted(capture_sizes, reverse=True):
+                self._dummy_run(
+                    num_tokens=self.parallel_config.max_num_batched_tokens,
+                    batch_size=batch_size,
+                    in_capturing=True,
+                    expected_decode_len=expected_decode_len,
+                )
+                logger.info(f"Warm up the model with the batch size:{batch_size}, num tokens:{expected_decode_len}")

         time_after_capture = time.perf_counter()
         logger.info(f"Cuda Graph capturing took {time_after_capture - time_before_capture} seconds")
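To make the capture order above easier to follow, the sketch below enumerates the dummy-run configurations issued by the MTP branch. It is a plain-Python paraphrase of the hunk, not library code; the halved batch size presumably reflects that each MTP decode step carries one draft token in addition to the current token per request (hence the `batch_size % 2 == 0` assertion), and the concrete `expected_decode_len` values are the ones used above.

from typing import List, Tuple


def mtp_capture_plan(capture_sizes: List[int]) -> List[Tuple[str, int, int, bool]]:
    """Return (phase, dummy_run_batch_size, expected_decode_len, accept_all_drafts) tuples."""
    plan: List[Tuple[str, int, int, bool]] = []
    # Target model: every capture size except 1, run at half the batch size.
    for bs in sorted(capture_sizes, reverse=True):
        if bs != 1:
            assert bs % 2 == 0
            plan.append(("target", bs // 2, 1, False))
    # Draft model: same sizes, longer expected decode and forced draft acceptance.
    for bs in sorted(capture_sizes, reverse=True):
        if bs != 1:
            assert bs % 2 == 0
            plan.append(("draft", bs // 2, 3, True))
    # Draft model at batch size 1 is captured separately, without forced acceptance.
    if 1 in capture_sizes:
        plan.append(("draft", 1, 3, False))
    return plan


for phase, bs, decode_len, accept_all in mtp_capture_plan([1, 2, 4, 8]):
    print(f"{phase:>6}  batch_size={bs}  expected_decode_len={decode_len}  accept_all_drafts={accept_all}")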
@@ -1427,12 +1473,16 @@ class GPUModelRunner(ModelRunnerBase):
                     self.share_inputs["image_features"],
                     self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]
                 hidden_states = model_output
             else:
                 model_output = self.model(
                     ids_remove_padding=self.share_inputs["ids_remove_padding"],
                     forward_meta=self.forward_meta,
                 )
+                if self.use_cudagraph:
+                    model_output = model_output[: self.real_token_num]
                 hidden_states = rebuild_padding(
                     model_output,
                     self.share_inputs["cu_seqlens_q"],
@@ -1628,20 +1678,22 @@ class GPUModelRunner(ModelRunnerBase):
         # TODO(gongshaotian): Optimize the management logic of kvcache
         self.num_gpu_blocks = self.parallel_config.total_block_num
         self.initialize_kv_cache(profile=True)
+        if self.speculative_method in ["mtp"]:
+            self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True)

         # 1. Profile with multimodal encoder & encoder cache

         # 2. Dummy run
         self._dummy_run(
             num_tokens=self.parallel_config.max_num_batched_tokens,
-            batch_size=min(self.parallel_config.max_num_seqs, 3),
+            batch_size=self.parallel_config.max_num_seqs,
         )

         # 3. gc
         self.clear_cache()

         if self.speculative_method in ["mtp"]:
-            self.proposer.clear_dummy_input()
+            self.proposer.clear_mtp_cache()

     def update_share_input_block_num(self, num_gpu_blocks: int) -> None:
         """
@@ -1671,7 +1723,7 @@ class GPUModelRunner(ModelRunnerBase):
         )

         if self.speculative_method in ["mtp"]:
-            self.proposer.update_block_num(num_gpu_blocks)
+            self.proposer.update_mtp_block_num(num_gpu_blocks)

     def cal_theortical_kvcache(self):
         """
@@ -1756,6 +1808,7 @@ class GPUModelRunner(ModelRunnerBase):
         # To adapt to CUDA Graph, keep the forward pass at the maximum batch size.
         if self.use_cudagraph:
             self.forward_meta.seq_lens_this_time = self.seq_lens_this_time_buffer
+            self.real_token_num = self.forward_meta.ids_remove_padding.shape[0]
         return

     def _init_image_preprocess(self) -> None:
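The new `self.real_token_num` records how many tokens are actually present before each step, so that when the forward pass replays a captured CUDA graph on fixed-shape, padded inputs, the output can be sliced back down — which is what the two `model_output[: self.real_token_num]` additions earlier in this diff do. A toy NumPy illustration of that pad-then-trim pattern, with made-up sizes standing in for Paddle tensors:

import numpy as np

graph_token_capacity = 8                                  # fixed shape the graph was captured with
ids_remove_padding = np.array([101, 102, 103, 104, 105])  # this step's real tokens

real_token_num = ids_remove_padding.shape[0]
padded_ids = np.zeros(graph_token_capacity, dtype=ids_remove_padding.dtype)
padded_ids[:real_token_num] = ids_remove_padding

# The replayed graph produces one hidden state per padded slot ...
model_output = np.random.rand(graph_token_capacity, 4)    # [num_tokens, hidden_size]

# ... so the padded tail is dropped before any per-token post-processing.
model_output = model_output[:real_token_num]
assert model_output.shape == (real_token_num, 4)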
|
@@ -207,7 +207,7 @@ class GpuWorker(WorkerBase):
         """
         Perform the warm-up and the graph optimization
         """
-        if self.model_runner.graph_opt_level >= 1:
+        if self.fd_config.graph_opt_config.graph_opt_level >= 1:
             self.model_runner.sot_warmup()
         # Triger cuda grpah capture
         self.model_runner.capture_model()
|