[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)

* support static cachekv c8 quantization in mtp mode

* optimize memory allocation
Author: freeliuzc
Date: 2025-11-21 15:10:13 +08:00 (committed via GitHub)
Parent: 3c36283d7d
Commit: 2d1dade5e2
6 changed files with 350 additions and 295 deletions
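Here "static C8" means an 8-bit FP8 (E4M3) KV cache with offline-calibrated scales: instead of deriving a quantization scale per token at runtime (the existing block_wise_fp8 dynamic path), the kernel reads per-head scales loaded once from kv_cache_scale.json. The memory change resizes the MTP proposer's target_hidden_states buffer to the scheduler's actual token budget; both changes are walked through below.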


@@ -602,7 +602,8 @@ template <typename T,
int VecSize = 4,
int RoundType = 0,
int HeadDim = 128,
bool IsFP8 = false>
bool IsFP8 = false,
bool IsDynamic = true>
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
// gqa_group_size, head_size]
@@ -662,8 +663,6 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
(head_idx - num_heads) % gqa_group_size * block_size +
block_offset;
}
T* cache_k_scale_now = cache_k_scale + cache_offset;
T* cache_v_scale_now = cache_v_scale + cache_offset;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
@@ -811,25 +810,34 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
if constexpr (IsDynamic) {
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(bias_vec1[i]));
local_max = __hmax(local_max, __habs(bias_vec2[i]));
}
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(bias_vec1[i]));
local_max = __hmax(local_max, __habs(bias_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max =
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
local_max =
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}
scale = __hdiv(448, local_max);
if (lane_id == 0) {
scale = __hdiv(448, local_max);
T* cache_k_scale_now = cache_k_scale + cache_offset;
T* cache_v_scale_now = cache_v_scale + cache_offset;
if (lane_id == 0) {
if (head_idx < num_heads + gqa_group_size) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}
} else {
if (head_idx < num_heads + gqa_group_size) {
cache_k_scale_now[0] = __hdiv(1, scale);
scale = __ldg(&cache_k_scale[kv_head_idx]);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
scale = __ldg(&cache_v_scale[kv_head_idx]);
}
}
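The if constexpr (IsDynamic) branch above is the heart of the kernel change: the dynamic (block-wise) path derives an FP8 scale from the warp-wide absolute maximum and stores its reciprocal for later dequantization, while the new static path simply loads a precalibrated per-head scale with __ldg. A minimal, self-contained sketch of the same pattern follows; the kernel and buffer names are hypothetical, not the FastDeploy source.

// Sketch only: one warp quantizes one 32-element head slice to FP8 E4M3.
// quantize_head_fp8 and its buffers are illustrative names.
#include <cuda_fp8.h>

template <bool IsDynamic>
__global__ void quantize_head_fp8(const float* __restrict__ x,          // [num_heads, 32]
                                  const float* __restrict__ head_scale, // [num_heads], static mode
                                  float* __restrict__ out_scale,        // [num_heads], dynamic mode
                                  __nv_fp8_e4m3* __restrict__ y) {
  const int head = blockIdx.x;
  const int lane = threadIdx.x;  // launched with 32 threads per block
  const float v = x[head * 32 + lane];
  float scale;
  if constexpr (IsDynamic) {
    // Per-token scale: butterfly-reduce |v| across the warp, the same
    // __shfl_xor_sync pattern as the kernel above.
    float local_max = fabsf(v);
    for (int m = 16; m > 0; m >>= 1)
      local_max = fmaxf(local_max, __shfl_xor_sync(0xffffffff, local_max, m));
    scale = 448.0f / local_max;                     // 448 = max finite E4M3 value
    if (lane == 0) out_scale[head] = 1.0f / scale;  // reciprocal kept for dequant
  } else {
    // Static C8: offline-calibrated per-head scale; no reduction, no store.
    scale = __ldg(&head_scale[head]);
  }
  y[head * 32 + lane] = __nv_fp8_e4m3(v * scale);
}
// e.g. quantize_head_fp8<false><<<num_heads, 32>>>(x, scales, nullptr, y);

Static mode skips both the warp reduction and the per-token scale store; the tradeoff is that the offline-calibrated scale has to cover the activation range actually seen at runtime.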


@@ -17,32 +17,32 @@
template <typename T, typename QKV_TYPE>
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
T* key_cache,
T* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
const float* qkv_out_scales,
const T* qkv_biases,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool use_neox_style,
const float* q_norm_weight,
const float* k_norm_weight,
const float rms_norm_eps,
const bool rope_3d) {
int output_inner_dim = num_heads + 2 * kv_num_heads;
const uint32_t elem_nums =
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
@@ -55,35 +55,34 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
PD_THROW(
"append_speculate_cache_rope_qk_norm not support neox rope yet");
PD_THROW("append_speculate_cache_rope_qk_norm not support neox rope yet");
} else {
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
rope_3d);
<<<grid_size, block_dim, 0, stream>>>(qkv,
key_cache,
value_cache,
qkv_out,
block_tables,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
output_inner_dim,
dim_head,
block_size,
elem_nums,
kv_num_heads,
q_norm_weight,
k_norm_weight,
rms_norm_eps,
rope_3d);
}
}
@@ -175,33 +174,33 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
}
}
template <typename T>
void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
T* cache_k_scale,
T* cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool rope_3d,
const float rms_norm_eps) {
template <typename T, bool IsDynamic = true>
void append_speculate_cache_fp8_rope(const T* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
T* qkv_out,
const int* block_tables,
const int* batch_id_per_token,
const int* cu_seqlens_q,
const int* seq_lens,
const int* seq_lens_encoder,
const float* cos_emb,
const float* sin_emb,
T* cache_k_scale,
T* cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int block_size,
const int bsz,
const int token_num,
const cudaStream_t& stream,
const bool rope_3d,
const float rms_norm_eps) {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
@@ -220,7 +219,12 @@ void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
num_heads,
block_size,
kv_num_heads);
append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T, 4, 0, 128, true>
append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T,
4,
0,
128,
true,
IsDynamic>
<<<grids, num_warps * 32, 0, stream>>>(qkv,
key_cache,
value_cache,
@@ -247,7 +251,7 @@ void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE, bool IsFP8=false>
template <typename T, typename QKV_TYPE, bool IsFP8 = false>
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
uint8_t* key_cache,
uint8_t* value_cache,
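Note the rename in the hunks above: append_speculate_cache_fp8_dynamic_rope becomes append_speculate_cache_fp8_rope, and the dynamic/static choice moves into the new IsDynamic template argument threaded through to the kernel. Because the mode is fixed at compile time, each instantiation compiles only its own scale logic and the static path adds no runtime branching.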
@@ -489,7 +493,6 @@ void SpeculateWriteCacheWithRoPEKernel(
auto num_heads = meta_data.q_num_heads;
auto kv_num_heads = meta_data.kv_num_heads;
const float* cos_emb =
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
const float* sin_emb;
@@ -515,8 +518,8 @@ void SpeculateWriteCacheWithRoPEKernel(
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
@@ -532,209 +535,243 @@ void SpeculateWriteCacheWithRoPEKernel(
rms_norm_eps,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
append_speculate_cache_fp8_dynamic_rope(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
rope_3d,
rms_norm_eps
);
append_speculate_cache_fp8_rope<DataType_, true>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_v_scale.get().data<T>())),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_fp8_rope<DataType_, false>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_v_scale.get().data<T>())),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
"speculate_append_decode_cache_rope_qk_norm just supports "
"cache_quant_type "
"none/block_wise_fp8/cache_fp8");
}
} else {
if (cache_quant_type_str == "none") {
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
append_speculate_cache_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_int8") {
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
append_speculate_cache_int8_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "cache_fp8") {
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
append_speculate_cache_fp8_dynamic_rope(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
nullptr, // q_norm_weight
nullptr, // k_norm_weight
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
rope_3d,
rms_norm_eps
);
append_speculate_cache_fp8_rope(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_v_scale.get().data<T>())),
nullptr, // q_norm_weight
nullptr, // k_norm_weight
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
append_speculate_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
qkv_biases ? reinterpret_cast<DataType_*>(
const_cast<T*>(qkv_biases.get().data<T>()))
: nullptr,
cache_k_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_scale.get().data<T>()))
: nullptr,
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
cache_v_scale ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_scale.get().data<T>()))
: nullptr,
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
cache_k_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_k_zp.get().data<T>()))
: nullptr,
cache_v_zp ? reinterpret_cast<DataType_*>(
const_cast<T*>(cache_v_zp.get().data<T>()))
: nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads,
dim_head,
block_size,
bsz,
token_nums,
stream,
use_neox_rotary_style,
rope_3d);
} else {
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
PD_THROW(
"cache_quant_type_str should be one of [none, cache_int8, "
"cache_int4_zp]");
}
}
}
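Taken together, the qk_norm dispatch now covers three cache layouts: "none" writes the unquantized cache, "block_wise_fp8" instantiates append_speculate_cache_fp8_rope with IsDynamic=true, and the new "cache_fp8" branch is the static C8 path, the identical call with IsDynamic=false that consumes the per-head scales loaded from kv_cache_scale.json instead of deriving them per token. In the non-qk_norm branch, block_wise_fp8 calls the same wrapper and relies on the IsDynamic default of true.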
@@ -827,7 +864,6 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
const paddle::optional<paddle::Tensor>& k_norm_weight,
const float rms_norm_eps);
template void
SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
const AppendAttnMetaData& meta_data,


@@ -198,6 +198,8 @@ class ModelConfig:
self.pooler_config: Optional["PoolerConfig"] = field(init=False)
self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
self.revision = None
self.prefix_layer_name = "layers"
self.kv_cache_quant_scale_path = ""
self.partial_rotary_factor: float = 1.0
self.num_nextn_predict_layers = 0
@@ -244,6 +246,7 @@ class ModelConfig:
self.enable_mm = is_multimodal_model
self.kv_cache_quant_scale_path = os.path.join(self.model, "kv_cache_scale.json")
if self.runner_type == "pooling":
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
@@ -1589,6 +1592,10 @@ class FDConfig:
else:
self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
self.scheduler_config.max_chunk_len = (
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
)
if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)


@@ -475,15 +475,16 @@ def deal_state_dict(state_dict):
src_tensor._share_data_with(dst_tensor)
def load_cache_scale(model_path, fd_config, state_dict):
file_path = os.path.join(model_path, "kv_cache_scale.json")
def load_cache_scale(fd_config, state_dict):
file_path = fd_config.model_config.kv_cache_quant_scale_path
prefix_layer_name = fd_config.model_config.prefix_layer_name
if os.path.exists(file_path):
with open(file_path, "r") as f:
data = json.load(f)
for i in range(fd_config.model_config.num_hidden_layers):
k_scale_name = f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"
v_scale_name = f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"
k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"
k_scale = data[k_scale_name]
k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
@@ -547,6 +548,6 @@ def load_composite_checkpoint(
if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
if kv_cache_quant_type == "float8_e4m3fn":
load_cache_scale(model_path, fd_config, state_dict)
load_cache_scale(fd_config, state_dict)
return state_dict
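For reference, the scale file the loader reads might look like the following. The key format comes from the f-strings above (the prefix segment is "layers" for the main model and "mtp_block" for the MTP draft layer, per the proposer change below); the values are purely illustrative:

{
  "ernie.layers.0.self_attn.cachek_matmul.activation_scale": 0.0123,
  "ernie.layers.0.self_attn.cachev_matmul.activation_scale": 0.0087
}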


@@ -268,7 +268,9 @@ class SchedulerConfig:
Exception: If invalid scheduler type is specified
"""
self.name = "local" # "local" for LocalScheduler or "global" for GlobalScheduler
self.max_num_batched_tokens = 2048
self.max_num_batched_tokens = 2048 # base token_num for text inputs
self.max_extra_num_batched_tokens = 16384 # extra token_num for multimodal inputs
self.max_chunk_len = 18432 # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
self.max_num_seqs = 34
self.splitwise_role = "mixed"
self.config = None


@@ -104,6 +104,7 @@ class MTPProposer(Proposer):
self.model_config.num_hidden_layers = 1
self.model_config.model = self.speculative_config.model
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
self.model_config.prefix_layer_name = "mtp_block"
if self.speculative_config.quantization != "":
self.model_config.quantization = self.speculative_config.quantization
self.model_config.start_layer_index = self.num_main_model_layers
@@ -354,7 +355,7 @@ class MTPProposer(Proposer):
self.target_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["target_hidden_states"] = paddle.full(
[self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
)
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
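The resize above is the "optimize memory allocation" half of the commit: target_hidden_states is now sized by scheduler_config.max_chunk_len (18432 rows with the defaults, 2048 base plus 16384 extra) rather than max_model_len * max_prefill_batch. For a hypothetical 128K-context model with hidden_size 8192 in bfloat16, that takes this buffer from about 2 GiB per prefill slot down to roughly 18432 × 8192 × 2 B ≈ 288 MiB.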