mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Speculative Decoding][MTP] Support static CacheKV C8 quantization and optimize memory usage (#5155)
* support static cachekv c8 quantization in mtp mode * optimize memory allocation
This commit is contained in:
@@ -602,7 +602,8 @@ template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool IsFP8 = false>
|
||||
bool IsFP8 = false,
|
||||
bool IsDynamic = true>
|
||||
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
const T* __restrict__ quant_qkv, // [num_head, num_heads + 2 *
|
||||
// gqa_group_size, head_size]
|
||||
@@ -662,8 +663,6 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
(head_idx - num_heads) % gqa_group_size * block_size +
|
||||
block_offset;
|
||||
}
|
||||
T* cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T* cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
@@ -811,25 +810,34 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
|
||||
}
|
||||
}
|
||||
// reduce max, 1 head per warp
|
||||
T local_max = -INFINITY;
|
||||
if constexpr (IsDynamic) {
|
||||
T local_max = -INFINITY;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(bias_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(bias_vec2[i]));
|
||||
}
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(bias_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(bias_vec2[i]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
||||
local_max =
|
||||
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
||||
local_max =
|
||||
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
|
||||
scale = __hdiv(448, local_max);
|
||||
|
||||
if (lane_id == 0) {
|
||||
scale = __hdiv(448, local_max);
|
||||
T* cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T* cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
if (lane_id == 0) {
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (head_idx < num_heads + gqa_group_size) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
scale = __ldg(&cache_k_scale[kv_head_idx]);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
scale = __ldg(&cache_v_scale[kv_head_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,32 +17,32 @@
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps,
|
||||
const bool rope_3d) {
|
||||
int output_inner_dim = num_heads + 2 * kv_num_heads;
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? token_num * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
@@ -55,35 +55,34 @@ void append_speculate_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
if (use_neox_style) {
|
||||
PD_THROW(
|
||||
"append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
PD_THROW("append_speculate_cache_rope_qk_norm not support neox rope yet");
|
||||
} else {
|
||||
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
|
||||
append_speculate_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
<<<grid_size, block_dim, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales,
|
||||
qkv_biases,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
output_inner_dim,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
q_norm_weight,
|
||||
k_norm_weight,
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,33 +174,33 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* cache_k_scale,
|
||||
T* cache_v_scale,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
template <typename T, bool IsDynamic = true>
|
||||
void append_speculate_cache_fp8_rope(const T* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* batch_id_per_token,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* cache_k_scale,
|
||||
T* cache_v_scale,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const int token_num,
|
||||
const cudaStream_t& stream,
|
||||
const bool rope_3d,
|
||||
const float rms_norm_eps) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
@@ -220,7 +219,12 @@ void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
|
||||
num_heads,
|
||||
block_size,
|
||||
kv_num_heads);
|
||||
append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T, 4, 0, 128, true>
|
||||
append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
true,
|
||||
IsDynamic>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(qkv,
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -247,7 +251,7 @@ void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
|
||||
rms_norm_eps);
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE, bool IsFP8=false>
|
||||
template <typename T, typename QKV_TYPE, bool IsFP8 = false>
|
||||
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -489,7 +493,6 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
auto num_heads = meta_data.q_num_heads;
|
||||
auto kv_num_heads = meta_data.kv_num_heads;
|
||||
|
||||
|
||||
const float* cos_emb =
|
||||
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
|
||||
const float* sin_emb;
|
||||
@@ -515,8 +518,8 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
@@ -532,209 +535,243 @@ void SpeculateWriteCacheWithRoPEKernel(
|
||||
rms_norm_eps,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
append_speculate_cache_fp8_dynamic_rope(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
rope_3d,
|
||||
rms_norm_eps
|
||||
);
|
||||
append_speculate_cache_fp8_rope<DataType_, true>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_v_scale.get().data<T>())),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_fp8_rope<DataType_, false>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_v_scale.get().data<T>())),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||
"speculate_append_decode_cache_rope_qk_norm just supports "
|
||||
"cache_quant_type "
|
||||
"none/block_wise_fp8/cache_fp8");
|
||||
}
|
||||
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
append_speculate_cache_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
append_speculate_cache_int8_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
append_speculate_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
append_speculate_cache_fp8_dynamic_rope(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
|
||||
nullptr, // q_norm_weight
|
||||
nullptr, // k_norm_weight
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
rope_3d,
|
||||
rms_norm_eps
|
||||
);
|
||||
append_speculate_cache_fp8_rope(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_v_scale.get().data<T>())),
|
||||
nullptr, // q_norm_weight
|
||||
nullptr, // k_norm_weight
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "cache_int4_zp") {
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
append_speculate_cache_int4_rope(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(const_cast<T*>(qkv_out->data<T>())),
|
||||
block_tables.data<int>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
token_nums,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
PD_THROW(
|
||||
"cache_quant_type_str should be one of [none, cache_int8, "
|
||||
"cache_int4_zp]");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -827,7 +864,6 @@ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||
const float rms_norm_eps);
|
||||
|
||||
|
||||
template void
|
||||
SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
|
||||
@@ -198,6 +198,8 @@ class ModelConfig:
|
||||
self.pooler_config: Optional["PoolerConfig"] = field(init=False)
|
||||
self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
|
||||
self.revision = None
|
||||
self.prefix_layer_name = "layers"
|
||||
self.kv_cache_quant_scale_path = ""
|
||||
|
||||
self.partial_rotary_factor: float = 1.0
|
||||
self.num_nextn_predict_layers = 0
|
||||
@@ -244,6 +246,7 @@ class ModelConfig:
|
||||
|
||||
self.enable_mm = is_multimodal_model
|
||||
|
||||
self.kv_cache_quant_scale_path = os.path.join(self.model, "kv_cache_scale.json")
|
||||
if self.runner_type == "pooling":
|
||||
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
|
||||
|
||||
@@ -1589,6 +1592,10 @@ class FDConfig:
|
||||
else:
|
||||
self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
|
||||
|
||||
self.scheduler_config.max_chunk_len = (
|
||||
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_extra_num_batched_tokens
|
||||
)
|
||||
|
||||
if self.long_prefill_token_threshold == 0:
|
||||
self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
|
||||
|
||||
|
||||
@@ -475,15 +475,16 @@ def deal_state_dict(state_dict):
|
||||
src_tensor._share_data_with(dst_tensor)
|
||||
|
||||
|
||||
def load_cache_scale(model_path, fd_config, state_dict):
|
||||
file_path = os.path.join(model_path, "kv_cache_scale.json")
|
||||
def load_cache_scale(fd_config, state_dict):
|
||||
file_path = fd_config.model_config.kv_cache_quant_scale_path
|
||||
prefix_layer_name = fd_config.model_config.prefix_layer_name
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for i in range(fd_config.model_config.num_hidden_layers):
|
||||
|
||||
k_scale_name = f"ernie.layers.{i}.self_attn.cachek_matmul.activation_scale"
|
||||
v_scale_name = f"ernie.layers.{i}.self_attn.cachev_matmul.activation_scale"
|
||||
k_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachek_matmul.activation_scale"
|
||||
v_scale_name = f"ernie.{prefix_layer_name}.{i}.self_attn.cachev_matmul.activation_scale"
|
||||
|
||||
k_scale = data[k_scale_name]
|
||||
k_scale_tensor = paddle.to_tensor(k_scale, dtype=paddle.get_default_dtype())
|
||||
@@ -547,6 +548,6 @@ def load_composite_checkpoint(
|
||||
if hasattr(fd_config.quant_config, "kv_cache_quant_type"):
|
||||
kv_cache_quant_type = fd_config.quant_config.kv_cache_quant_type
|
||||
if kv_cache_quant_type == "float8_e4m3fn":
|
||||
load_cache_scale(model_path, fd_config, state_dict)
|
||||
load_cache_scale(fd_config, state_dict)
|
||||
|
||||
return state_dict
|
||||
|
||||
@@ -268,7 +268,9 @@ class SchedulerConfig:
|
||||
Exception: If invalid scheduler type is specified
|
||||
"""
|
||||
self.name = "local" # "local" for LocalScheduler or "global" for GlobalScheduler
|
||||
self.max_num_batched_tokens = 2048
|
||||
self.max_num_batched_tokens = 2048 # base token_num for text inputs
|
||||
self.max_extra_num_batched_tokens = 16384 # extra token_num for multimodal inputs
|
||||
self.max_chunk_len = 18432 # max supported token_num = max_num_batched_tokens + max_extra_num_batched_tokens
|
||||
self.max_num_seqs = 34
|
||||
self.splitwise_role = "mixed"
|
||||
self.config = None
|
||||
|
||||
@@ -104,6 +104,7 @@ class MTPProposer(Proposer):
|
||||
self.model_config.num_hidden_layers = 1
|
||||
self.model_config.model = self.speculative_config.model
|
||||
self.model_config.pretrained_config.prefix_name = "ernie.mtp_block"
|
||||
self.model_config.prefix_layer_name = "mtp_block"
|
||||
if self.speculative_config.quantization != "":
|
||||
self.model_config.quantization = self.speculative_config.quantization
|
||||
self.model_config.start_layer_index = self.num_main_model_layers
|
||||
@@ -354,7 +355,7 @@ class MTPProposer(Proposer):
|
||||
self.target_model_inputs["decoder_tile_ids_per_batch"]
|
||||
)
|
||||
self.model_inputs["target_hidden_states"] = paddle.full(
|
||||
[self.max_model_len * self.fd_config.max_prefill_batch, self.model_config.hidden_size], 0, dtype="bfloat16"
|
||||
[self.fd_config.scheduler_config.max_chunk_len, self.model_config.hidden_size], 0, dtype="bfloat16"
|
||||
)
|
||||
|
||||
tmp_position_ids = paddle.arange(self.model_config.max_model_len).reshape((1, -1))
|
||||
|
||||
Reference in New Issue
Block a user