Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[MTP][Cfp8]supports spec dynamic cfp8 (#4290)
* supports spec dynamic cfp8

* supports spec dynamic cfp8

---------

Co-authored-by: freeliuzc <lzc842650834@gmail.com>
@@ -585,6 +585,273 @@ __global__ void append_speculate_cache_neox_rope_kernel(
  }
}

template <typename T,
          int VecSize = 4,
          int RoundType = 0,
          int HeadDim = 128,
          bool IsFP8 = false>
__global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
    const T* __restrict__ quant_qkv,  // [num_head, num_heads + 2 *
                                      // gqa_group_size, head_size]
    uint8_t* __restrict__ key_cache,    // [num_blocks, gqa_group_size,
                                        // block_size, head_size // 2]
    uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                        // block_size, head_size // 2]
    T* __restrict__ qkv_out,
    const int* __restrict__ block_tables,        // [bsz, max_blocks_per_seq]
    const int* __restrict__ batch_id_per_token,  // [num_tokens]
    const int* __restrict__ cu_seqlens_q,
    const int* __restrict__ seq_lens,          // [bsz]
    const int* __restrict__ seq_lens_encoder,  // [bsz]
    const float* __restrict__ cos_emb,
    const float* __restrict__ sin_emb,
    T* __restrict__ cache_k_scale,
    T* __restrict__ cache_v_scale,
    const float* q_norm_weight,
    const float* k_norm_weight,
    const int max_seq_len,
    const int max_blocks_per_seq,
    const int num_heads,
    const int block_size,
    const float max_bound,
    const float min_bound,
    const int gqa_group_size,
    const bool rope_3d,
    const float rms_norm_eps) {
  static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
  static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
  constexpr int NUM_WARPS = 4;
  const int tid = threadIdx.x;
  const int wid = tid / 32;
  const int lane_id = tid % 32;
  const int token_id = blockIdx.x;

  const int bid = batch_id_per_token[token_id];

  const int start_token_idx = cu_seqlens_q[bid];
  const int head_idx = blockIdx.y * NUM_WARPS + wid;
  int q_head_idx, k_head_idx, v_idx;
  const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * HeadDim;
  constexpr int half_head_size = HeadDim / 2;
  if (seq_lens_encoder[bid] > 0) return;
  const int write_seq_id = seq_lens[bid] + token_id - start_token_idx;
  if (write_seq_id == 0) return;
  const int* block_table_now = block_tables + bid * max_blocks_per_seq;
  const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
  const int block_offset = write_seq_id % block_size;

  int cache_offset;
  if (head_idx < num_heads) {
    cache_offset = 0;
  } else if (head_idx < num_heads + 2 * gqa_group_size) {
    cache_offset = block_idx * gqa_group_size * block_size +
                   (head_idx - num_heads) % gqa_group_size * block_size +
                   block_offset;
  }
  T* cache_k_scale_now = cache_k_scale + cache_offset;
  T* cache_v_scale_now = cache_v_scale + cache_offset;

  float thread_m2 = 0.0f;
  float warp_m2 = 0.0f;

  if (head_idx < num_heads) {
    // q
    using LoadT = AlignedVector<T, VecSize>;
    using LoadBiasT = AlignedVector<T, VecSize>;
    using LoadOutScaleT = AlignedVector<float, VecSize>;
    constexpr int HalfVecSize = VecSize / 2;
    using LoadEmbT = AlignedVector<float, HalfVecSize>;

    LoadT src_vec;
    LoadBiasT bias_vec;
    LoadOutScaleT out_scale_vec;
    LoadEmbT cos_emb_vec;
    LoadEmbT sin_emb_vec;
    const T* qkv_now = quant_qkv + token_id * hidden_size;
    T* qkv_out_now = qkv_out + token_id * hidden_size;
#pragma unroll
    for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
         head_bias += 32 * VecSize) {
      const int bias_idx = head_idx * HeadDim + head_bias;
      Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);

      // q rope
      const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
      uint32_t new_emb_idx =
          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
      for (int i = 0; i < HalfVecSize; i++) {
        // dequant + add_bias + rope
        float input_left = static_cast<float>(src_vec[2 * i]);
        float input_right = static_cast<float>(src_vec[2 * i + 1]);
        const float cos_tmp = cos_emb_vec[i];
        const float sin_tmp = sin_emb_vec[i];
        float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
        float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
        thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
        bias_vec[2 * i] = static_cast<T>(tmp1);
        bias_vec[2 * i + 1] = static_cast<T>(tmp2);
      }
      // qk norm
      if (q_norm_weight) {
        WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
        float row_variance = max(warp_m2 / HeadDim, 0.0f);
        float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
        LoadOutScaleT q_norm_vec;
        Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
#pragma unroll
        for (int i = 0; i < VecSize; i++) {
          bias_vec[i] = static_cast<T>(static_cast<float>(bias_vec[i]) *
                                       row_inv_var * q_norm_vec[i]);
        }
      }
      Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
    }
  } else if (head_idx < num_heads + 2 * gqa_group_size) {
    // k
    constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t);  // 16
    using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
    const uint32_t kv_head_idx = (head_idx - num_heads) % gqa_group_size;

    constexpr int K_VEC_SIZE = 4;
    constexpr int HALF_K_VEC_SIZE = 2;
    using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
    using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
    using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
    using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
    using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
    using LoadEmbT = AlignedVector<float, 1>;
    LoadKVResT cache_vec;
    LoadT src_vec1, src_vec2;
    LoadBiasT bias_vec1, bias_vec2;
    LoadOutScaleT out_scale_vec1, out_scale_vec2;
    LoadEmbT cos_emb_vec1, cos_emb_vec2;
    LoadEmbT sin_emb_vec1, sin_emb_vec2;

    const T* qkv_now = quant_qkv + token_id * hidden_size;
    const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
    const int bias_idx = head_idx * HeadDim + head_bias;
    Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
    Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
    T scale = T(1.0f);
    const int k_head_idx = head_idx - num_heads;
    const int v_head_idx = head_idx - num_heads - gqa_group_size;
    if (head_idx < num_heads + gqa_group_size) {
      const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
      uint32_t new_emb_idx =
          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
      Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
      Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
      Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
      Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
    }

    float input_left = static_cast<float>(src_vec1[0]);
    float input_right = static_cast<float>(src_vec1[1]);
    if (head_idx < num_heads + gqa_group_size) {
      float cos_tmp = cos_emb_vec1[0];
      float sin_tmp = sin_emb_vec1[0];
      float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
      float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
      thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
      bias_vec1[0] = static_cast<T>(tmp1);
      bias_vec1[1] = static_cast<T>(tmp2);
    } else {
      bias_vec1[0] = static_cast<T>(input_left);
      bias_vec1[1] = static_cast<T>(input_right);
    }

    input_left = static_cast<float>(src_vec2[0]);
    input_right = static_cast<float>(src_vec2[1]);
    if (head_idx < num_heads + gqa_group_size) {
      float cos_tmp = cos_emb_vec2[0];
      float sin_tmp = sin_emb_vec2[0];
      float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
      float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
      thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
      bias_vec2[0] = static_cast<T>(tmp1);
      bias_vec2[1] = static_cast<T>(tmp2);
    } else {
      bias_vec2[0] = static_cast<T>(input_left);
      bias_vec2[1] = static_cast<T>(input_right);
    }
    if (k_norm_weight) {
      if (head_idx < num_heads + gqa_group_size) {
        LoadOutScaleT k_norm_vec1, k_norm_vec2;
        Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
        Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8],
                                     &k_norm_vec2);
        // qk norm
        WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
        float row_variance = max(warp_m2 / HeadDim, 0.0f);
        float row_inv_var = Rsqrt(row_variance + rms_norm_eps);

        for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
          bias_vec1[i] = static_cast<T>(static_cast<float>(bias_vec1[i]) *
                                        row_inv_var * k_norm_vec1[i]);
          bias_vec2[i] = static_cast<T>(static_cast<float>(bias_vec2[i]) *
                                        row_inv_var * k_norm_vec2[i]);
        }
      }
    }
    // reduce max, 1 head per warp
    T local_max = -INFINITY;
#pragma unroll
    for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
      local_max = __hmax(local_max, __habs(bias_vec1[i]));
      local_max = __hmax(local_max, __habs(bias_vec2[i]));
    }
#pragma unroll
    for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
      local_max =
          __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
    }

    scale = __hdiv(448, local_max);

    if (lane_id == 0) {
      if (head_idx < num_heads + gqa_group_size) {
        cache_k_scale_now[0] = __hdiv(1, scale);
      } else {
        cache_v_scale_now[0] = __hdiv(1, scale);
      }
    }

#pragma unroll
    for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
      cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
          scale, bias_vec1[i], max_bound, min_bound);
      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
          scale, bias_vec2[i], max_bound, min_bound);
    }
    if (head_idx < num_heads + gqa_group_size) {
      const int start_block_16 =
          block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
      const uint32_t tgt_cache_idx =
          block_idx * gqa_group_size * block_size * HeadDim +
          kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
          lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
      Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
    } else {
      const uint32_t base_tgt_cache_idx =
          block_idx * gqa_group_size * HeadDim * block_size +
          kv_head_idx * HeadDim * block_size +
          (lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
          block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
      const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
                                      block_offset % 8 / 2 * 4     // per 4
                                      + block_offset % 16 / 8 * 2  // per 2
                                      + block_offset % 2;          // per 1
      const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
      const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
      const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
      value_cache[tgt_cache_idx1] = cache_vec[0];
      value_cache[tgt_cache_idx2] = cache_vec[1];
      value_cache[tgt_cache_idx3] = cache_vec[2];
      value_cache[tgt_cache_idx4] = cache_vec[3];
    }
  }
}

template <typename T,
          int VecSize = 4,
          int RoundType = 0,
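In the K/V branch of the kernel above, each warp owns one key or value head: it reduces the absolute maximum of the rotated values with a butterfly `__shfl_xor_sync`, maps that maximum onto the FP8 E4M3 range via `scale = 448 / local_max` (448 is the largest finite E4M3 value), has lane 0 write the reciprocal into `cache_k_scale` / `cache_v_scale`, and then quantizes through `QuantToC8`. Below is a minimal host-side sketch of that arithmetic, assuming plain `float` in place of the kernel's half-precision type and a simple clamp in place of `QuantToC8`; the sample values are illustrative only, not taken from the diff.

// Host-side sketch of the dynamic FP8 scaling (assumptions: float stands in
// for T; the clamp only approximates QuantToC8's rounding behaviour).
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
  const float fp8_e4m3_max = 448.0f;                  // largest finite E4M3 value
  float head_vals[4] = {0.31f, -1.7f, 0.02f, 0.9f};   // one head's K values (illustrative)

  // 1) max-abs over the head (the kernel does this with a warp shuffle reduction)
  float local_max = 0.0f;
  for (float v : head_vals) local_max = std::max(local_max, std::fabs(v));

  // 2) scale so the largest magnitude maps to the FP8 limit
  float scale = fp8_e4m3_max / local_max;

  // 3) the cache keeps 1/scale so dequantization is a single multiply
  float cache_scale = 1.0f / scale;

  for (float v : head_vals) {
    float q = std::max(std::min(v * scale, fp8_e4m3_max), -fp8_e4m3_max);
    printf("quantized %.4f -> %.1f, dequantized %.4f\n", v, q, q * cache_scale);
  }
  return 0;
}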
@@ -175,6 +175,78 @@ void append_speculate_cache_rope(const QKV_TYPE* qkv,
  }
}

template <typename T>
void append_speculate_cache_fp8_dynamic_rope(const T* qkv,
                                             uint8_t* key_cache,
                                             uint8_t* value_cache,
                                             T* qkv_out,
                                             const int* block_tables,
                                             const int* batch_id_per_token,
                                             const int* cu_seqlens_q,
                                             const int* seq_lens,
                                             const int* seq_lens_encoder,
                                             const float* cos_emb,
                                             const float* sin_emb,
                                             T* cache_k_scale,
                                             T* cache_v_scale,
                                             const float* q_norm_weight,
                                             const float* k_norm_weight,
                                             const int max_seq_len,
                                             const int max_blocks_per_seq,
                                             const int num_heads,
                                             const int kv_num_heads,
                                             const int dim_head,
                                             const int block_size,
                                             const int bsz,
                                             const int token_num,
                                             const cudaStream_t& stream,
                                             const bool rope_3d,
                                             const float rms_norm_eps) {
  constexpr int num_warps = 4;
  const int all_warps =
      ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
  dim3 grids(token_num, all_warps / num_warps);

  append_clear_cache_int8_block<4>
      <<<grids, num_warps * 32, 0, stream>>>(key_cache,
                                             value_cache,
                                             seq_lens,
                                             block_tables,
                                             batch_id_per_token,
                                             cu_seqlens_q,
                                             seq_lens_encoder,
                                             max_seq_len,
                                             max_blocks_per_seq,
                                             num_heads,
                                             block_size,
                                             kv_num_heads);
  append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel<T, 4, 0, 128, true>
      <<<grids, num_warps * 32, 0, stream>>>(qkv,
                                             key_cache,
                                             value_cache,
                                             qkv_out,
                                             block_tables,
                                             batch_id_per_token,
                                             cu_seqlens_q,
                                             seq_lens,
                                             seq_lens_encoder,
                                             cos_emb,
                                             sin_emb,
                                             cache_k_scale,
                                             cache_v_scale,
                                             q_norm_weight,
                                             k_norm_weight,
                                             max_seq_len,
                                             max_blocks_per_seq,
                                             num_heads,
                                             block_size,
                                             127.0f,
                                             -127.0f,
                                             kv_num_heads,
                                             rope_3d,
                                             rms_norm_eps);
}

template <typename T, typename QKV_TYPE, bool IsFP8=false>
void append_speculate_cache_int8_rope(const QKV_TYPE* qkv,
                                      uint8_t* key_cache,
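The launcher pairs a cache-clearing pass with the quantizing kernel and sizes both launches identically: one block per draft token along x, one block per group of `num_warps` heads along y, and `num_warps * 32` threads so each warp handles one head. A small sketch of that grid arithmetic follows; the head and token counts are illustrative examples, not values from the diff.

// Sketch of the grid-shape arithmetic used by append_speculate_cache_fp8_dynamic_rope.
#include <cstdio>

int main() {
  const int num_heads = 32;     // query heads (example)
  const int kv_num_heads = 8;   // GQA key/value heads (example)
  const int token_num = 5;      // speculative draft tokens this step (example)
  const int num_warps = 4;

  // Round q + k + v heads up to a multiple of num_warps.
  const int all_warps =
      ((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;

  // grid = (token_num, all_warps / num_warps), block = num_warps * 32 threads.
  printf("grid = (%d, %d), block = %d threads\n",
         token_num, all_warps / num_warps, num_warps * 32);
  // With these numbers: 32 + 16 = 48 heads -> 48 warps -> grid (5, 12), 128 threads/block.
  return 0;
}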
@@ -459,6 +531,35 @@ void SpeculateWriteCacheWithRoPEKernel(
          reinterpret_cast<const float*>(k_norm_weight.get().data<float>()),
          rms_norm_eps,
          rope_3d);
    } else if (cache_quant_type_str == "block_wise_fp8") {
      append_speculate_cache_fp8_dynamic_rope(
          reinterpret_cast<const DataType_*>(qkv_ptr),
          key_cache_out->data<uint8_t>(),
          value_cache_out->data<uint8_t>(),
          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
          block_tables.data<int>(),
          batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(),
          seq_lens.data<int>(),
          seq_lens_encoder.data<int>(),
          cos_emb,
          sin_emb,
          const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
          const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
          q_norm_weight.get().data<float>(),
          k_norm_weight.get().data<float>(),
          max_seq_len,
          max_blocks_per_seq,
          num_heads,
          kv_num_heads,
          dim_head,
          block_size,
          bsz,
          token_nums,
          stream,
          rope_3d,
          rms_norm_eps);
    } else {
      PD_THROW(
          "append_decode_cache_rope_qk_norm not support cachekv quant yet");
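This dispatch site is the qk-norm path of SpeculateWriteCacheWithRoPEKernel, so it forwards the real `q_norm_weight` / `k_norm_weight` tensors; the second dispatch site below passes `nullptr` for both. Because the kernel branches on those pointers at runtime, the normalized and plain RoPE paths share the same `append_speculate_cache_fp8_dynamic_rope` helper and kernel instantiation.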
@@ -561,6 +662,35 @@ void SpeculateWriteCacheWithRoPEKernel(
          stream,
          use_neox_rotary_style,
          rope_3d);
    } else if (cache_quant_type_str == "block_wise_fp8") {
      append_speculate_cache_fp8_dynamic_rope(
          reinterpret_cast<const DataType_*>(qkv_ptr),
          key_cache_out->data<uint8_t>(),
          value_cache_out->data<uint8_t>(),
          reinterpret_cast<DataType_*>(qkv_out->data<T>()),
          block_tables.data<int>(),
          batch_id_per_token.data<int>(),
          cu_seqlens_q.data<int>(),
          seq_lens.data<int>(),
          seq_lens_encoder.data<int>(),
          cos_emb,
          sin_emb,
          const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
          const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_v_scale.get().data<T>())),
          nullptr,  // q_norm_weight
          nullptr,  // k_norm_weight
          max_seq_len,
          max_blocks_per_seq,
          num_heads,
          kv_num_heads,
          dim_head,
          block_size,
          bsz,
          token_nums,
          stream,
          rope_3d,
          rms_norm_eps);
    } else if (cache_quant_type_str == "cache_int4_zp") {
      append_speculate_cache_int4_rope(
          reinterpret_cast<const QKV_TYPE*>(qkv_ptr),