mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
support qk norm (#3145)
This commit is contained in:
@@ -73,6 +73,9 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
|||||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
const bool rope_3d,
|
const bool rope_3d,
|
||||||
@@ -223,7 +226,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
|||||||
main_stream,
|
main_stream,
|
||||||
&qkv_out,
|
&qkv_out,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache));
|
const_cast<paddle::Tensor*>(&value_cache),
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps);
|
||||||
};
|
};
|
||||||
|
|
||||||
if (qkv_out_scales) {
|
if (qkv_out_scales) {
|
||||||
@@ -339,7 +345,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
|||||||
exec_stream,
|
exec_stream,
|
||||||
&qkv_out,
|
&qkv_out,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache));
|
const_cast<paddle::Tensor*>(&value_cache),
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps);
|
||||||
} else {
|
} else {
|
||||||
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
DecoderWriteCacheWithRoPEKernel<data_t, data_t>(
|
||||||
meta_data,
|
meta_data,
|
||||||
@@ -363,7 +372,10 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
|
|||||||
exec_stream,
|
exec_stream,
|
||||||
&qkv_out,
|
&qkv_out,
|
||||||
const_cast<paddle::Tensor*>(&key_cache),
|
const_cast<paddle::Tensor*>(&key_cache),
|
||||||
const_cast<paddle::Tensor*>(&value_cache));
|
const_cast<paddle::Tensor*>(&value_cache),
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -430,6 +442,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
const paddle::optional<paddle::Tensor>& out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
const paddle::optional<paddle::Tensor>& out_linear_smooths,
|
||||||
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
const paddle::optional<paddle::Tensor>& kv_signal_data,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
@@ -500,6 +515,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
out_linear_shifts,
|
out_linear_shifts,
|
||||||
out_linear_smooths,
|
out_linear_smooths,
|
||||||
kv_signal_data,
|
kv_signal_data,
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps,
|
||||||
cache_quant_type_str,
|
cache_quant_type_str,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
rope_3d,
|
rope_3d,
|
||||||
@@ -577,6 +595,9 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
|
|||||||
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
|
||||||
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& q_norm_weight_shape,
|
||||||
|
const paddle::optional<std::vector<int64_t>>& k_norm_weight_shape,
|
||||||
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
@@ -637,6 +658,9 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
|
|||||||
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
|
||||||
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
|
||||||
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& q_norm_weight_dtype,
|
||||||
|
const paddle::optional<paddle::DataType>& k_norm_weight_dtype,
|
||||||
|
const float rms_norm_eps,
|
||||||
const std::string& compute_dtype,
|
const std::string& compute_dtype,
|
||||||
const std::string& cache_quant_type_str,
|
const std::string& cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style,
|
const bool use_neox_rotary_style,
|
||||||
@@ -714,7 +738,9 @@ PD_BUILD_STATIC_OP(append_attention)
|
|||||||
paddle::Optional("cache_v_zp"),
|
paddle::Optional("cache_v_zp"),
|
||||||
paddle::Optional("out_linear_shifts"),
|
paddle::Optional("out_linear_shifts"),
|
||||||
paddle::Optional("out_linear_smooths"),
|
paddle::Optional("out_linear_smooths"),
|
||||||
paddle::Optional("kv_signal_data")})
|
paddle::Optional("kv_signal_data"),
|
||||||
|
paddle::Optional("q_norm_weight"),
|
||||||
|
paddle::Optional("k_norm_weight")})
|
||||||
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
.Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
|
||||||
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
.SetInplaceMap({{"key_cache", "key_cache_out"},
|
||||||
{"value_cache", "value_cache_out"}})
|
{"value_cache", "value_cache_out"}})
|
||||||
@@ -732,7 +758,8 @@ PD_BUILD_STATIC_OP(append_attention)
|
|||||||
"encoder_max_partition_size: int",
|
"encoder_max_partition_size: int",
|
||||||
"speculate_max_draft_token_num: int",
|
"speculate_max_draft_token_num: int",
|
||||||
"causal: bool",
|
"causal: bool",
|
||||||
"speculate_decoder: bool"})
|
"speculate_decoder: bool",
|
||||||
|
"rms_norm_eps: float"})
|
||||||
.SetKernelFn(PD_KERNEL(AppendAttention))
|
.SetKernelFn(PD_KERNEL(AppendAttention))
|
||||||
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
.SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
|
||||||
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
|
.SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
|
||||||
|
@@ -18,6 +18,142 @@
|
|||||||
#include "mma_tensor_op.cuh"
|
#include "mma_tensor_op.cuh"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
|
|
||||||
|
template <typename T, int VecSize = 1>
|
||||||
|
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||||
|
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||||
|
// head_size]
|
||||||
|
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||||
|
// head_size // 2]
|
||||||
|
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||||
|
// head_size // 2]
|
||||||
|
T* __restrict__ qkv_out,
|
||||||
|
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||||
|
const int* __restrict__ batch_id_per_token, // [num_tokens]
|
||||||
|
const int* __restrict__ cu_seqlens_q,
|
||||||
|
const int* __restrict__ seq_lens, // [bsz]
|
||||||
|
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||||
|
const float* __restrict__ cos_emb,
|
||||||
|
const float* __restrict__ sin_emb,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_blocks_per_seq,
|
||||||
|
const int num_heads,
|
||||||
|
const int head_size,
|
||||||
|
const int block_size,
|
||||||
|
const uint32_t elem_cnt,
|
||||||
|
const int kv_num_heads,
|
||||||
|
const bool rope_3d,
|
||||||
|
const T* q_norm_weight,
|
||||||
|
const T* k_norm_weight,
|
||||||
|
const float rms_norm_eps) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||||
|
using LoadKVT = AlignedVector<T, VecSize>;
|
||||||
|
constexpr int HalfVecSize = VecSize / 2;
|
||||||
|
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||||
|
LoadT src_vec;
|
||||||
|
LoadBiasT out_vec;
|
||||||
|
LoadKVT cache_vec;
|
||||||
|
LoadEmbT cos_emb_vec;
|
||||||
|
LoadEmbT sin_emb_vec;
|
||||||
|
|
||||||
|
int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
int64_t all_warp_num = gridDim.x * blockDim.x;
|
||||||
|
int64_t all_head_dim = elem_cnt / head_size;
|
||||||
|
|
||||||
|
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||||
|
// const int64_t offset = 2 * hidden_size;
|
||||||
|
const int half_head_size = head_size / 2;
|
||||||
|
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
|
||||||
|
int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
|
||||||
|
const int ori_bi = linear_index / hidden_size;
|
||||||
|
const int bias = linear_index % hidden_size;
|
||||||
|
const int hi = bias / head_size; // q + k + v
|
||||||
|
const int h_bias = bias % head_size;
|
||||||
|
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||||
|
if (seq_lens_encoder[ori_bi] > 0) return;
|
||||||
|
const int write_seq_id = seq_lens[ori_bi];
|
||||||
|
if (write_seq_id == 0) continue;
|
||||||
|
|
||||||
|
const int* block_table_now = nullptr;
|
||||||
|
|
||||||
|
block_table_now = block_tables + ori_bi * max_blocks_per_seq;
|
||||||
|
const int block_idx = block_table_now[write_seq_id / block_size];
|
||||||
|
const int block_offset = write_seq_id % block_size;
|
||||||
|
const uint32_t ori_idx =
|
||||||
|
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||||
|
|
||||||
|
const int bias_idx = hi * head_size + h_bias;
|
||||||
|
Load<T, VecSize>(&quant_qkv[ori_idx], &src_vec);
|
||||||
|
if (hi < num_heads + kv_num_heads) {
|
||||||
|
// q k rope
|
||||||
|
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||||
|
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||||
|
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||||
|
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||||
|
}
|
||||||
|
float thread_m2 = 0.0f;
|
||||||
|
float warp_m2 = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < HalfVecSize; i++) {
|
||||||
|
// dequant + add_bias + rope
|
||||||
|
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||||
|
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||||
|
|
||||||
|
if (hi < num_heads + kv_num_heads) {
|
||||||
|
const float cos_tmp = cos_emb_vec[i];
|
||||||
|
const float sin_tmp = sin_emb_vec[i];
|
||||||
|
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||||
|
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||||
|
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||||
|
out_vec[2 * i] =
|
||||||
|
static_cast<T>(tmp1);
|
||||||
|
out_vec[2 * i + 1] =
|
||||||
|
static_cast<T>(tmp2);
|
||||||
|
} else {
|
||||||
|
out_vec[2 * i] = src_vec[2 * i];
|
||||||
|
out_vec[2 * i + 1] = src_vec[2 * i + 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hi < (num_heads + kv_num_heads)) { // q k
|
||||||
|
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||||
|
float row_variance =
|
||||||
|
max(warp_m2 / head_size, 0.0f);
|
||||||
|
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||||
|
LoadT q_norm_vec, k_norm_vec;
|
||||||
|
if (hi < num_heads) { // q
|
||||||
|
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||||
|
}
|
||||||
|
} else { // k
|
||||||
|
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (hi < num_heads) {
|
||||||
|
// write q
|
||||||
|
Store<T, VecSize>(out_vec, &qkv_out[ori_idx]);
|
||||||
|
} else {
|
||||||
|
// quant + write k/v
|
||||||
|
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
|
||||||
|
const uint32_t tgt_idx =
|
||||||
|
block_idx * kv_num_heads * block_size * head_size +
|
||||||
|
kv_head_idx * block_size * head_size + block_offset * head_size +
|
||||||
|
h_bias;
|
||||||
|
if (hi < num_heads + kv_num_heads) {
|
||||||
|
Store<T, VecSize>(out_vec, &key_cache[tgt_idx]);
|
||||||
|
} else {
|
||||||
|
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, int VecSize = 1>
|
template <typename T, int VecSize = 1>
|
||||||
__global__ void append_decode_cache_T_rope_kernel(
|
__global__ void append_decode_cache_T_rope_kernel(
|
||||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||||
|
@@ -15,6 +15,70 @@
|
|||||||
#include "decoder_write_cache_with_rope_kernel.h"
|
#include "decoder_write_cache_with_rope_kernel.h"
|
||||||
#include "utils.cuh"
|
#include "utils.cuh"
|
||||||
|
|
||||||
|
template <typename T, typename QKV_TYPE>
|
||||||
|
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||||
|
T* key_cache,
|
||||||
|
T* value_cache,
|
||||||
|
T* qkv_out,
|
||||||
|
const int* block_tables,
|
||||||
|
const int* batch_id_per_token,
|
||||||
|
const int* cu_seqlens_q,
|
||||||
|
const int* seq_lens,
|
||||||
|
const int* seq_lens_encoder,
|
||||||
|
const float* cos_emb,
|
||||||
|
const float* sin_emb,
|
||||||
|
const float* qkv_out_scales,
|
||||||
|
const T* qkv_biases,
|
||||||
|
const int max_seq_len,
|
||||||
|
const int max_blocks_per_seq,
|
||||||
|
const int num_heads,
|
||||||
|
const int kv_num_heads,
|
||||||
|
const int dim_head,
|
||||||
|
const int block_size,
|
||||||
|
const int bsz,
|
||||||
|
const cudaStream_t& stream,
|
||||||
|
const bool use_neox_style,
|
||||||
|
const bool rope_3d,
|
||||||
|
const T* q_norm_weight,
|
||||||
|
const T* k_norm_weight,
|
||||||
|
const float rms_norm_eps) {
|
||||||
|
const uint32_t elem_nums =
|
||||||
|
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||||
|
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||||
|
assert(dim_head == 128 && "dim_head must be 128");
|
||||||
|
constexpr int HEAD_DIM = 128;
|
||||||
|
|
||||||
|
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||||
|
const int pack_num = elem_nums / PackSize;
|
||||||
|
const int blocksize = 128;
|
||||||
|
int grid_size = 1;
|
||||||
|
GetNumBlocks<128>(pack_num, &grid_size);
|
||||||
|
dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
|
||||||
|
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
|
||||||
|
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
qkv_out,
|
||||||
|
block_tables,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_encoder,
|
||||||
|
cos_emb,
|
||||||
|
sin_emb,
|
||||||
|
max_seq_len,
|
||||||
|
max_blocks_per_seq,
|
||||||
|
num_heads,
|
||||||
|
dim_head,
|
||||||
|
block_size,
|
||||||
|
elem_nums,
|
||||||
|
kv_num_heads,
|
||||||
|
rope_3d,
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename QKV_TYPE>
|
template <typename T, typename QKV_TYPE>
|
||||||
void append_decode_cache_rope(const QKV_TYPE* qkv,
|
void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||||
T* key_cache,
|
T* key_cache,
|
||||||
@@ -441,7 +505,10 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out) {
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps) {
|
||||||
typedef cascade_attn_type_traits<T> traits_;
|
typedef cascade_attn_type_traits<T> traits_;
|
||||||
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
typedef cascade_attn_type_traits<QKV_TYPE> qkt_nv_type_;
|
||||||
typedef typename traits_::type DataType_;
|
typedef typename traits_::type DataType_;
|
||||||
@@ -464,6 +531,43 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (q_norm_weight && k_norm_weight) {
|
||||||
|
if (cache_quant_type_str == "none") {
|
||||||
|
append_decode_cache_rope_qk_norm(
|
||||||
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
|
reinterpret_cast<DataType_*>(key_cache_out->data<T>()),
|
||||||
|
reinterpret_cast<DataType_*>(value_cache_out->data<T>()),
|
||||||
|
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||||
|
block_tables.data<int>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
|
cu_seqlens_q.data<int>(),
|
||||||
|
seq_lens.data<int>(),
|
||||||
|
seq_lens_encoder.data<int>(),
|
||||||
|
cos_emb,
|
||||||
|
sin_emb,
|
||||||
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
|
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||||
|
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||||
|
: nullptr,
|
||||||
|
max_seq_len,
|
||||||
|
max_blocks_per_seq,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads,
|
||||||
|
dim_head,
|
||||||
|
block_size,
|
||||||
|
bsz,
|
||||||
|
stream,
|
||||||
|
use_neox_rotary_style,
|
||||||
|
rope_3d,
|
||||||
|
reinterpret_cast<const DataType_*>(q_norm_weight.get().data<T>()),
|
||||||
|
reinterpret_cast<const DataType_*>(k_norm_weight.get().data<T>()),
|
||||||
|
rms_norm_eps);
|
||||||
|
} else {
|
||||||
|
PD_THROW(
|
||||||
|
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if (cache_quant_type_str == "none") {
|
if (cache_quant_type_str == "none") {
|
||||||
append_decode_cache_rope(
|
append_decode_cache_rope(
|
||||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||||
@@ -640,6 +744,7 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
|
"cache_quant_type_str should be one of [none, cache_int8, cache_fp8 "
|
||||||
"cache_int4_zp]");
|
"cache_int4_zp]");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -667,7 +772,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
|
||||||
template void
|
template void
|
||||||
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
||||||
@@ -694,7 +802,10 @@ DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
|
||||||
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
@@ -720,7 +831,10 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
|
||||||
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
||||||
const AppendAttnMetaData& meta_data,
|
const AppendAttnMetaData& meta_data,
|
||||||
@@ -746,4 +860,7 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
@@ -40,4 +40,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight, const float rms_norm_eps);
|
||||||
|
@@ -358,7 +358,7 @@ __global__ void GQAVariableLengthRotaryKernel(
|
|||||||
linear_index < elem_cnt;
|
linear_index < elem_cnt;
|
||||||
linear_index += step) {
|
linear_index += step) {
|
||||||
const int token_idx = linear_index / offset;
|
const int token_idx = linear_index / offset;
|
||||||
const int ori_bi = batch_id_per_token[token_idx];;
|
const int ori_bi = batch_id_per_token[token_idx];
|
||||||
if (seq_lens[ori_bi] == 0) continue;
|
if (seq_lens[ori_bi] == 0) continue;
|
||||||
const int bias = linear_index % offset;
|
const int bias = linear_index % offset;
|
||||||
const int hi = bias / last_dim;
|
const int hi = bias / last_dim;
|
||||||
@@ -405,6 +405,94 @@ __global__ void GQAVariableLengthRotaryKernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T, int VecSize = 1>
|
||||||
|
__global__ void GQAVariableLengthRotaryQKNormKernel(
|
||||||
|
const T *qkv,
|
||||||
|
const float *cos_emb,
|
||||||
|
const float *sin_emb,
|
||||||
|
const int *batch_id_per_token,
|
||||||
|
const int *cu_seqlens_q,
|
||||||
|
const int *seq_lens,
|
||||||
|
const int *seq_lens_decoder,
|
||||||
|
T *qkv_out,
|
||||||
|
const int64_t elem_cnt,
|
||||||
|
const int q_num_head,
|
||||||
|
const int kv_num_head,
|
||||||
|
const int seq_len,
|
||||||
|
const int last_dim,
|
||||||
|
const bool rope_3d,
|
||||||
|
const T* q_norm_weight,
|
||||||
|
const T* k_norm_weight,
|
||||||
|
const float rms_norm_eps
|
||||||
|
) {
|
||||||
|
using LoadT = AlignedVector<T, VecSize>;
|
||||||
|
constexpr int HalfVecSize = VecSize / 2;
|
||||||
|
using LoadEmbT = AlignedVector<float, HalfVecSize>;
|
||||||
|
LoadT src_vec;
|
||||||
|
LoadEmbT cos_emb_vec;
|
||||||
|
LoadEmbT sin_emb_vec;
|
||||||
|
int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
|
int64_t all_warp_num = gridDim.x * blockDim.x;
|
||||||
|
const int half_lastdim = last_dim / 2;
|
||||||
|
const int offset = (q_num_head + kv_num_head) * last_dim;
|
||||||
|
const int all_head_num = elem_cnt / last_dim;
|
||||||
|
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
|
||||||
|
int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
|
||||||
|
const int token_idx = linear_index / offset;
|
||||||
|
const int ori_bi = batch_id_per_token[token_idx];
|
||||||
|
if (seq_lens[ori_bi] == 0) continue;
|
||||||
|
const int bias = linear_index % offset;
|
||||||
|
const int hi = bias / last_dim;
|
||||||
|
const int h_bias = bias % last_dim;
|
||||||
|
|
||||||
|
const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
|
||||||
|
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
|
||||||
|
const int64_t base_idx =
|
||||||
|
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
|
||||||
|
h_bias;
|
||||||
|
Load<T, VecSize>(&qkv[base_idx], &src_vec);
|
||||||
|
|
||||||
|
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
|
||||||
|
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||||
|
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||||
|
|
||||||
|
float thread_m2 = 0.0f;
|
||||||
|
float warp_m2 = 0.0f;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < HalfVecSize; i++) {
|
||||||
|
const float input_left = static_cast<float>(src_vec[2 * i]);
|
||||||
|
const float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||||
|
const float cos_tmp = cos_emb_vec[i];
|
||||||
|
const float sin_tmp = sin_emb_vec[i];
|
||||||
|
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||||
|
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||||
|
src_vec[2 * i] = static_cast<T>(tmp1);
|
||||||
|
src_vec[2 * i + 1] = static_cast<T>(tmp2);
|
||||||
|
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||||
|
}
|
||||||
|
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||||
|
float row_variance =
|
||||||
|
max(warp_m2 / last_dim, 0.0f);
|
||||||
|
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||||
|
LoadT q_norm_vec, k_norm_vec;
|
||||||
|
if (hi < q_num_head) {
|
||||||
|
Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
|
||||||
|
for (int i = 0; i < VecSize; i++) {
|
||||||
|
src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, int VecSize = 1>
|
template <typename T, int VecSize = 1>
|
||||||
__global__ void GQAVariableLengthRotaryKernel(
|
__global__ void GQAVariableLengthRotaryKernel(
|
||||||
const T *qkv,
|
const T *qkv,
|
||||||
@@ -1568,6 +1656,66 @@ void rotary_qk_variable(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, typename QKV_TYPE>
|
||||||
|
void gqa_rotary_qk_norm_variable(
|
||||||
|
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||||
|
const QKV_TYPE *qkv_input, // qkv
|
||||||
|
const float *qkv_out_scales, // [3, num_head, dim_head]
|
||||||
|
const T *qkv_bias,
|
||||||
|
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
|
||||||
|
const int *batch_id_per_token,
|
||||||
|
const int *cu_seqlens_q,
|
||||||
|
const int *seq_lens,
|
||||||
|
const int *seq_lens_decoder,
|
||||||
|
const int token_num,
|
||||||
|
const int num_heads,
|
||||||
|
const int kv_num_heads,
|
||||||
|
const int seq_len,
|
||||||
|
const int input_output_len,
|
||||||
|
const int dim_head,
|
||||||
|
const cudaStream_t &stream,
|
||||||
|
bool use_neox_style = false,
|
||||||
|
bool rope_3d = false,
|
||||||
|
const T *q_norm_weight = nullptr,
|
||||||
|
const T *k_norm_weight = nullptr,
|
||||||
|
const float rms_norm_eps = 1e-6) {
|
||||||
|
int64_t elem_nums =
|
||||||
|
qkv_out_scales
|
||||||
|
? token_num * (num_heads + 2 * kv_num_heads) * dim_head
|
||||||
|
: token_num * (num_heads + kv_num_heads) * dim_head; // for all q k v
|
||||||
|
assert(dim_head == 128 && "dim_head must be 128");
|
||||||
|
constexpr int HEAD_DIM = 128;
|
||||||
|
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||||
|
const int pack_num = elem_nums / PackSize;
|
||||||
|
const int blocksize = 128;
|
||||||
|
int grid_size = 1;
|
||||||
|
GetNumBlocks<128>(pack_num, &grid_size);
|
||||||
|
dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
|
||||||
|
|
||||||
|
const float *cos_emb = rotary_emb;
|
||||||
|
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
|
||||||
|
|
||||||
|
GQAVariableLengthRotaryQKNormKernel<T, PackSize>
|
||||||
|
<<<grid_size, Blocks, 0, stream>>>(
|
||||||
|
reinterpret_cast<const T *>(qkv_input),
|
||||||
|
cos_emb,
|
||||||
|
sin_emb,
|
||||||
|
batch_id_per_token,
|
||||||
|
cu_seqlens_q,
|
||||||
|
seq_lens,
|
||||||
|
seq_lens_decoder,
|
||||||
|
qkv_out,
|
||||||
|
elem_nums,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads,
|
||||||
|
seq_len,
|
||||||
|
dim_head,
|
||||||
|
rope_3d,
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, typename QKV_TYPE>
|
template <typename T, typename QKV_TYPE>
|
||||||
void gqa_rotary_qk_variable(
|
void gqa_rotary_qk_variable(
|
||||||
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
T *qkv_out, // [token_num, 3, num_head, dim_head]
|
||||||
|
@@ -46,7 +46,10 @@ void EncoderWriteCacheWithRopeKernel(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out) {
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps) {
|
||||||
auto token_num = meta_data.token_nums;
|
auto token_num = meta_data.token_nums;
|
||||||
auto num_heads = meta_data.q_num_heads;
|
auto num_heads = meta_data.q_num_heads;
|
||||||
auto kv_num_heads = meta_data.kv_num_heads;
|
auto kv_num_heads = meta_data.kv_num_heads;
|
||||||
@@ -56,6 +59,35 @@ void EncoderWriteCacheWithRopeKernel(
|
|||||||
is_scale_channel_wise = true;
|
is_scale_channel_wise = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (q_norm_weight && k_norm_weight) {
|
||||||
|
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
|
||||||
|
gqa_rotary_qk_norm_variable(
|
||||||
|
qkv_out->data<T>(),
|
||||||
|
qkv.data<QKV_TYPE>(),
|
||||||
|
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||||
|
qkv_biases ? qkv_biases.get().data<T>() : nullptr,
|
||||||
|
rotary_embs.get().data<float>(),
|
||||||
|
batch_id_per_token.data<int>(),
|
||||||
|
cu_seqlens_q.data<int>(),
|
||||||
|
seq_lens_encoder.data<int>(),
|
||||||
|
seq_lens_decoder.data<int>(),
|
||||||
|
token_num,
|
||||||
|
num_heads,
|
||||||
|
kv_num_heads,
|
||||||
|
max_seq_len,
|
||||||
|
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
|
||||||
|
head_dim,
|
||||||
|
stream,
|
||||||
|
use_neox_style,
|
||||||
|
rope_3d,
|
||||||
|
q_norm_weight ? q_norm_weight.get().data<T>() : nullptr,
|
||||||
|
k_norm_weight ? k_norm_weight.get().data<T>() : nullptr,
|
||||||
|
rms_norm_eps);
|
||||||
|
} else {
|
||||||
|
PD_THROW(
|
||||||
|
"gqa_rotary_qk_norm_variable only support gqa mode. channel wise scale and neox style are not supported");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if (num_heads == kv_num_heads) {
|
if (num_heads == kv_num_heads) {
|
||||||
rotary_qk_variable(
|
rotary_qk_variable(
|
||||||
qkv_out->data<T>(),
|
qkv_out->data<T>(),
|
||||||
@@ -121,6 +153,7 @@ void EncoderWriteCacheWithRopeKernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
const uint32_t block_size = meta_data.block_size;
|
const uint32_t block_size = meta_data.block_size;
|
||||||
if (cache_quant_type_str == "none") {
|
if (cache_quant_type_str == "none") {
|
||||||
CascadeAppendWriteCacheKVQKV<T>(meta_data,
|
CascadeAppendWriteCacheKVQKV<T>(meta_data,
|
||||||
|
@@ -43,4 +43,7 @@ EncoderWriteCacheWithRopeKernel<paddle::bfloat16, paddle::bfloat16>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::bfloat16, int>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, paddle::float16>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
@@ -42,4 +42,7 @@ template void EncoderWriteCacheWithRopeKernel<paddle::float16, int>(
|
|||||||
cudaStream_t& stream,
|
cudaStream_t& stream,
|
||||||
paddle::Tensor* qkv_out,
|
paddle::Tensor* qkv_out,
|
||||||
paddle::Tensor* key_cache_out,
|
paddle::Tensor* key_cache_out,
|
||||||
paddle::Tensor* value_cache_out);
|
paddle::Tensor* value_cache_out,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps);
|
||||||
|
@@ -559,3 +559,37 @@ template <typename T, bool IsFP8>inline __device__ static void convert_c8(T * re
|
|||||||
convert_int8(result, source);
|
convert_int8(result, source);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Default thread-group width used by the shuffle-based reduction helpers
// (WelfordWarpReduce / WelfordWarpAllReduce take it as their default
// `thread_group_width` template argument).
constexpr int kWarpSize = 32;
|
||||||
|
|
||||||
|
// Folds one partial value into a running accumulator: *m2 += b_m2.
//
// b_m2: partial value contributed by another lane.
// m2:   in/out accumulator owned by the calling thread.
//
// Despite the "Welford" name, only a single summed quantity is combined here;
// no mean or count is tracked.
template<typename T>
|
||||||
|
inline __device__ void WelfordCombine1(T b_m2, T* m2) {
|
||||||
|
*m2 += b_m2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sums `thread_m2` across a group of `thread_group_width` lanes using XOR
// shuffles and writes the result to *m2.
//
// thread_m2: this thread's contribution.
// m2:        output; initialised with thread_m2, then accumulated in place.
//
// The loop halves `mask` from thread_group_width / 2 down to 1; at each step a
// lane adds the value currently held by lane (lane_id XOR mask).
// NOTE(review): presumably thread_group_width is a power of two no larger than
// 32 — confirm against callers.
// NOTE(review): the shuffle is issued with the full member mask 0xffffffff, so
// it looks like all 32 lanes of the warp are expected to reach this call
// together — confirm callers never invoke it from divergent code.
template<typename T, int thread_group_width = kWarpSize>
|
||||||
|
__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) {
|
||||||
|
*m2 = thread_m2;
|
||||||
|
for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) {
|
||||||
|
// Fetch the partner lane's running value, then add it into ours.
T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask);
|
||||||
|
WelfordCombine1(b_m2, m2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Thin wrapper: forwards to WelfordWarpReduce with the same template
// arguments and the same (thread_m2, m2) pair; it adds no extra broadcast or
// synchronisation step of its own.
template<typename T, int thread_group_width = kWarpSize>
|
||||||
|
__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) {
|
||||||
|
WelfordWarpReduce<T, thread_group_width>(thread_m2, m2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type-dispatched reciprocal square root.
//
// The primary template is declared but not defined here; only the float and
// double specializations below provide a body, each returning rsqrt(x).
// NOTE(review): any other T would have no definition in the code visible
// here — confirm no caller instantiates Rsqrt with half/bfloat16 directly.
template <typename T>
|
||||||
|
__inline__ __device__ T Rsqrt(T x);
|
||||||
|
|
||||||
|
// float specialization.
template <>
|
||||||
|
__inline__ __device__ float Rsqrt<float>(float x) {
|
||||||
|
return rsqrt(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
// double specialization.
template <>
|
||||||
|
__inline__ __device__ double Rsqrt<double>(double x) {
|
||||||
|
return rsqrt(x);
|
||||||
|
}
|
||||||
|
@@ -78,6 +78,9 @@ std::vector<paddle::Tensor> AppendAttention(
|
|||||||
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
const paddle::optional<paddle::Tensor> &out_linear_shifts,
|
||||||
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
const paddle::optional<paddle::Tensor> &out_linear_smooths,
|
||||||
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
const paddle::optional<paddle::Tensor> &kv_signal_data,
|
||||||
|
const paddle::optional<paddle::Tensor>& q_norm_weight,
|
||||||
|
const paddle::optional<paddle::Tensor>& k_norm_weight,
|
||||||
|
const float rms_norm_eps,
|
||||||
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
const std::string &compute_dtype, const std::string &cache_quant_type_str,
|
||||||
const bool use_neox_rotary_style, const bool rope_3d,
|
const bool use_neox_rotary_style, const bool rope_3d,
|
||||||
const int max_input_length, const float quant_max_bound,
|
const int max_input_length, const float quant_max_bound,
|
||||||
|
@@ -43,6 +43,11 @@
|
|||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
break; \
|
break; \
|
||||||
} \
|
} \
|
||||||
|
case 48: { \
|
||||||
|
constexpr size_t NUM_EXPERTS_PER_RANK = 48; \
|
||||||
|
__VA_ARGS__ \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
case 64: { \
|
case 64: { \
|
||||||
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
|
constexpr size_t NUM_EXPERTS_PER_RANK = 64; \
|
||||||
__VA_ARGS__ \
|
__VA_ARGS__ \
|
||||||
|
@@ -262,6 +262,9 @@ class AppendAttentionBackend(AttentionBackend):
|
|||||||
layer.linear_shift,
|
layer.linear_shift,
|
||||||
layer.linear_smooth,
|
layer.linear_smooth,
|
||||||
metadata.kv_signal_data_list[layer.layer_id],
|
metadata.kv_signal_data_list[layer.layer_id],
|
||||||
|
getattr(layer, "q_norm_weight", None),
|
||||||
|
getattr(layer, "k_norm_weight", None),
|
||||||
|
getattr(layer, "rms_norm_eps", 1e-6),
|
||||||
metadata._fuse_kernel_compute_dtype,
|
metadata._fuse_kernel_compute_dtype,
|
||||||
getattr(layer, "cache_quant_type_str", "none"),
|
getattr(layer, "cache_quant_type_str", "none"),
|
||||||
layer.use_neox_rotary_style,
|
layer.use_neox_rotary_style,
|
||||||
|
@@ -28,6 +28,7 @@ from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethod
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||||
|
from fastdeploy.model_executor.layers.utils import get_tensor
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Layer):
|
class Attention(nn.Layer):
|
||||||
@@ -49,6 +50,7 @@ class Attention(nn.Layer):
|
|||||||
linear_smooth: paddle.Tensor = None,
|
linear_smooth: paddle.Tensor = None,
|
||||||
use_neox_rotary_style: bool = False,
|
use_neox_rotary_style: bool = False,
|
||||||
use_qk_norm: bool = False,
|
use_qk_norm: bool = False,
|
||||||
|
rms_norm_eps: float = 1e-6,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initializes `LMLayer` with the given parameters.
|
Initializes `LMLayer` with the given parameters.
|
||||||
@@ -63,6 +65,8 @@ class Attention(nn.Layer):
|
|||||||
prefix (str, optional): The name of current layer. Defaults to "".
|
prefix (str, optional): The name of current layer. Defaults to "".
|
||||||
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
|
linear_shift (Optional[paddle.Tensor], optional): The shift of linear. Defaults to None.
|
||||||
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
|
linear_smooth (Optional[paddle.Tensor], optional): The smooth of linear. Defaults to None.
|
||||||
|
use_qk_norm (bool, optional): Whether to apply rmsnorm on QK after rope. Defaults to False.
|
||||||
|
rms_norm_eps (float, optional): The epsilon of RMSNorm. Defaults to 1e-6.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the `v_head_dim` is less than 0.
|
ValueError: If the `v_head_dim` is less than 0.
|
||||||
@@ -102,6 +106,27 @@ class Attention(nn.Layer):
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
|
f"Attention is running in cache kv {self.kvcache_quant_method.cache_quant_config.quant_type} mode"
|
||||||
)
|
)
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm_key = f"{self.prefix}.q_norm"
|
||||||
|
self.k_norm_key = f"{self.prefix}.k_norm"
|
||||||
|
self.init_weight()
|
||||||
|
|
||||||
|
# Creates the two QK-norm parameters, `q_norm_weight` and `k_norm_weight`.
# Each has shape [self.qk_head_dim], dtype self._dtype, is_bias=False, and is
# filled with the constant 0 on creation.
# The zeros are placeholders: load_state_dict replaces both values via
# set_value when self.use_qk_norm is set.
def init_weight(self):
|
||||||
|
self.q_norm_weight = self.create_parameter(
|
||||||
|
shape=[self.qk_head_dim],
|
||||||
|
dtype=self._dtype,
|
||||||
|
is_bias=False,
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.k_norm_weight = self.create_parameter(
|
||||||
|
shape=[self.qk_head_dim],
|
||||||
|
dtype=self._dtype,
|
||||||
|
is_bias=False,
|
||||||
|
default_initializer=paddle.nn.initializer.Constant(0),
|
||||||
|
)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
|
def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
|
||||||
"""
|
"""
|
||||||
@@ -109,6 +134,11 @@ class Attention(nn.Layer):
|
|||||||
"""
|
"""
|
||||||
if self.kvcache_quant_method is not None:
|
if self.kvcache_quant_method is not None:
|
||||||
self.kvcache_quant_method.create_weights(self, state_dict)
|
self.kvcache_quant_method.create_weights(self, state_dict)
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.q_norm_key + ".weight")))
|
||||||
|
k_norm_weight_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.k_norm_key + ".weight")))
|
||||||
|
self.q_norm_weight.set_value(q_norm_weight_tensor)
|
||||||
|
self.k_norm_weight.set_value(k_norm_weight_tensor)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
@@ -60,6 +60,9 @@ def append_attention(
|
|||||||
linear_shift: Optional[paddle.Tensor] = None,
|
linear_shift: Optional[paddle.Tensor] = None,
|
||||||
linear_smooth: Optional[paddle.Tensor] = None,
|
linear_smooth: Optional[paddle.Tensor] = None,
|
||||||
kv_signal_data: Optional[paddle.Tensor] = None,
|
kv_signal_data: Optional[paddle.Tensor] = None,
|
||||||
|
q_norm_weight: Optional[paddle.Tensor] = None,
|
||||||
|
k_norm_weight: Optional[paddle.Tensor] = None,
|
||||||
|
rms_norm_eps: float = 1e-6,
|
||||||
compute_type: str = "bf16",
|
compute_type: str = "bf16",
|
||||||
cache_quant_type: str = "none",
|
cache_quant_type: str = "none",
|
||||||
use_neox_rotary_style: bool = False,
|
use_neox_rotary_style: bool = False,
|
||||||
@@ -114,6 +117,9 @@ def append_attention(
|
|||||||
linear_shift,
|
linear_shift,
|
||||||
linear_smooth,
|
linear_smooth,
|
||||||
kv_signal_data,
|
kv_signal_data,
|
||||||
|
q_norm_weight,
|
||||||
|
k_norm_weight,
|
||||||
|
rms_norm_eps,
|
||||||
compute_type,
|
compute_type,
|
||||||
cache_quant_type,
|
cache_quant_type,
|
||||||
use_neox_rotary_style,
|
use_neox_rotary_style,
|
||||||
|
@@ -17,6 +17,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
|
from paddle.incubate.nn.functional import fused_rms_norm
|
||||||
|
|
||||||
paddle.seed(10)
|
paddle.seed(10)
|
||||||
|
|
||||||
@@ -157,6 +158,8 @@ def naive_attention_impl(
|
|||||||
cache_k_dequant_scales=None,
|
cache_k_dequant_scales=None,
|
||||||
cache_v_dequant_scales=None,
|
cache_v_dequant_scales=None,
|
||||||
use_cachekv_int8="None",
|
use_cachekv_int8="None",
|
||||||
|
q_norm_weight=None,
|
||||||
|
k_norm_weight=None,
|
||||||
):
|
):
|
||||||
batch = query.shape[0]
|
batch = query.shape[0]
|
||||||
heads = query.shape[1]
|
heads = query.shape[1]
|
||||||
@@ -244,6 +247,27 @@ def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, dim_head
|
|||||||
return q, k, v, qkv
|
return q, k, v, qkv
|
||||||
|
|
||||||
|
|
||||||
|
# Reference QK-norm used by the test to build the expected result.
#
# Draws random norm weights of length `head_dim` (np.random.random scaled by
# 1/10) and converts them to tensors of `dtype`. It then flattens q and k to
# [-1, head_dim], applies fused_rms_norm to each, and reshapes them back to
# [-1, num_head, seq_len, dim_head] using the head counts read from the
# incoming shapes.
#
# Returns: (q, k, q_norm_weight_tensor, k_norm_weight_tensor).
#
# NOTE(review): epsilon is hard-coded to 1e-5 here, while the op call later in
# this test passes 1e-6 right after the norm weights — confirm which value is
# intended so the reference and the kernel agree.
# NOTE(review): the print calls look like leftover debugging output; printing
# the full q tensor twice is noisy for a unit test.
# NOTE(review): `bs` is unpacked but not used in this function.
def apply_qk_norm(head_dim, dtype, q, k):
|
||||||
|
q_norm_weight = np.random.random([head_dim]) / 10
|
||||||
|
k_norm_weight = np.random.random([head_dim]) / 10
|
||||||
|
q_norm_weight_tensor = paddle.to_tensor(q_norm_weight, dtype=dtype)
|
||||||
|
k_norm_weight_tensor = paddle.to_tensor(k_norm_weight, dtype=dtype)
|
||||||
|
print("q:", q.shape)
|
||||||
|
print("k:", k.shape)
|
||||||
|
bs, q_num_head, seq_len, dim_head = q.shape
|
||||||
|
_, kv_num_head, _, _ = k.shape
|
||||||
|
|
||||||
|
q = q.reshape([-1, head_dim])
|
||||||
|
k = k.reshape([-1, head_dim])
|
||||||
|
print("q:", q)
|
||||||
|
q = fused_rms_norm(q, q_norm_weight_tensor, None, 1e-5)[0]
|
||||||
|
print("q after norm:", q)
|
||||||
|
k = fused_rms_norm(k, k_norm_weight_tensor, None, 1e-5)[0]
|
||||||
|
q = q.reshape([-1, q_num_head, seq_len, dim_head])
|
||||||
|
k = k.reshape([-1, kv_num_head, seq_len, dim_head])
|
||||||
|
return q, k, q_norm_weight_tensor, k_norm_weight_tensor
|
||||||
|
|
||||||
|
|
||||||
def split_query_by_phase(
|
def split_query_by_phase(
|
||||||
query,
|
query,
|
||||||
seq_lens_encoder,
|
seq_lens_encoder,
|
||||||
@@ -324,6 +348,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
self.softmax_scale = self.dim_head**-0.5
|
self.softmax_scale = self.dim_head**-0.5
|
||||||
self.rope_theta = 10000
|
self.rope_theta = 10000
|
||||||
self.dtype = "float16"
|
self.dtype = "float16"
|
||||||
|
self.use_qk_norm = True
|
||||||
self.init_tensor()
|
self.init_tensor()
|
||||||
|
|
||||||
def init_tensor(self):
|
def init_tensor(self):
|
||||||
@@ -394,6 +419,11 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
|
q, k = self.rope._apply_rope(self.rope_emb, q, k, causal=True)
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q, k, q_norm_weight, k_norm_weight = apply_qk_norm(self.dim_head, self.dtype, q, k)
|
||||||
|
else:
|
||||||
|
q_norm_weight = None
|
||||||
|
k_norm_weight = None
|
||||||
out_ = naive_attention_impl(
|
out_ = naive_attention_impl(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@@ -476,6 +506,9 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
|
|||||||
None, # linear_shift
|
None, # linear_shift
|
||||||
None, # linear_smooth
|
None, # linear_smooth
|
||||||
None, # kv_signal_data
|
None, # kv_signal_data
|
||||||
|
q_norm_weight, # q_norm_weight
|
||||||
|
k_norm_weight, # k_norm_weight
|
||||||
|
1e-6,
|
||||||
"fp16",
|
"fp16",
|
||||||
"none", # cache_quant_type
|
"none", # cache_quant_type
|
||||||
self.use_neox_rotary_style,
|
self.use_neox_rotary_style,
|
||||||
@@ -580,6 +613,7 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
|
|||||||
self.softmax_scale = self.dim_head**-0.5
|
self.softmax_scale = self.dim_head**-0.5
|
||||||
self.rope_theta = 10000
|
self.rope_theta = 10000
|
||||||
self.dtype = "float16"
|
self.dtype = "float16"
|
||||||
|
self.use_qk_norm = False
|
||||||
self.init_tensor()
|
self.init_tensor()
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user