mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Graph Optimization][Speculative Decoding] Update yaml and fix typo (#4612)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
max_model_len: 32768
|
||||
max_num_seqs: 96
|
||||
gpu_memory_utilization: 0.9
|
||||
gpu_memory_utilization: 0.85
|
||||
kv_cache_ratio: 0.71
|
||||
tensor_parallel_size: 4
|
||||
quantization: wint4
|
||||
|
||||
@@ -74,7 +74,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
|
||||
const int token_id = linear_index / hidden_size;
|
||||
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (ori_bi == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
||||
if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
const int hi = bias / head_size; // q + k + v
|
||||
@@ -378,7 +378,7 @@ __global__ void append_speculate_cache_rope_kernel(
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (ori_bi == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
||||
if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
|
||||
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % hidden_size;
|
||||
@@ -508,7 +508,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
|
||||
linear_index += step) {
|
||||
const int token_id = linear_index / half_hidden_size;
|
||||
const int ori_bi = batch_id_per_token[token_id];
|
||||
if (ori_bi == -1) return; // NOTE(gongshaotian): For CUDAGraph padding
|
||||
if (ori_bi == -1) continue; // NOTE(gongshaotian): For CUDAGraph padding
|
||||
if (seq_lens_decoder[ori_bi] == 0) continue;
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
|
||||
Reference in New Issue
Block a user