FA3 support for Qwen3 (#5441)
@@ -17,6 +17,7 @@
#include "paddle/extension.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "qwen3_rope.h"
#include "remote_cache_kv_ipc.h"

template <typename T, int VecSize = 1>
@@ -173,7 +174,7 @@ void gqa_rotary_qk_split_variable(
    T *k,
    T *v,
    const T *qkv_input,
    const float *rotary_emb,  // [2, 1, 1, seq_len, head_dim / 2]
    const float *rotary_emb,  // [2, 1, seq_len, 1, head_dim / 2]
    const float *q_norm_weight,
    const float *k_norm_weight,
    const int *batch_id_per_token,
@@ -1136,6 +1137,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const int kv_token_num,
    const int max_seq_len,
    const float rms_norm_eps,
    const bool use_neox_rotary_style,
    const std::string &cache_quant_type,
    const bool rope_3d) {
  typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
@@ -1157,6 +1159,24 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
      qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
  const float softmax_scale = 1.f / sqrt(head_dim);

  PADDLE_ENFORCE_EQ(batch_id_per_token.dims().size(), 1);
  PADDLE_ENFORCE_EQ(batch_id_per_token.dims()[0], token_num);

  if (!rope_3d) {
    PADDLE_ENFORCE_EQ(rotary_embs.dims().size(), 5);
    PADDLE_ENFORCE_EQ(rotary_embs.dims()[0], 2);
    PADDLE_ENFORCE_EQ(rotary_embs.dims()[1], 1);
    PADDLE_ENFORCE_EQ(rotary_embs.dims()[2], max_seq_len);
    PADDLE_ENFORCE_EQ(rotary_embs.dims()[3], 1);
    if (use_neox_rotary_style) {
      // Note(ZKK): for Qwen3-like models the [0, head_dim/2) and
      // [head_dim/2, head_dim) halves of the rotary table are identical.
      PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim);
    } else {
      PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 2);
    }
  }

  AppendAttnMetaData meta_data;
  meta_data.token_nums = token_num;
  meta_data.kv_num_heads = kv_num_heads;
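The checks above pin the expected rotary table layout: `[2, 1, max_seq_len, 1, last_dim]`, where `last_dim == head_dim` for the neox (Qwen3-like) style with identical halves, and `head_dim / 2` otherwise. Below is a minimal NumPy sketch of a table that would satisfy these checks; construction details such as the frequency base are assumptions, not taken from this PR.

```python
import numpy as np

def build_rotary_embs(max_seq_len: int, head_dim: int, neox_style: bool, base: float = 10000.0):
    half = head_dim // 2
    inv_freq = 1.0 / (base ** (np.arange(half) / half))   # assumed RoPE frequencies
    angles = np.outer(np.arange(max_seq_len), inv_freq)    # [max_seq_len, head_dim/2]
    cos, sin = np.cos(angles), np.sin(angles)
    if neox_style:
        # Qwen3-like layout: duplicate the halves so the last dim equals head_dim,
        # matching PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim).
        cos = np.concatenate([cos, cos], axis=-1)
        sin = np.concatenate([sin, sin], axis=-1)
    # stack cos/sin and insert the two singleton dims -> [2, 1, max_seq_len, 1, last_dim]
    return np.stack([cos, sin])[:, None, :, None, :].astype(np.float32)

rot = build_rotary_embs(max_seq_len=8, head_dim=128, neox_style=True)
assert rot.shape == (2, 1, 8, 1, 128)
```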
@@ -1175,30 +1195,49 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
  paddle::Tensor v = GetEmptyTensor(
      {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place());

  // rope
  gqa_rotary_qk_split_variable<data_t>(
      qkv_out.data<data_t>(),
      q.data<data_t>(),
      k.data<data_t>(),
      v.data<data_t>(),
      qkv.data<data_t>(),
      rotary_embs.data<float>(),
      q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
      k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
      batch_id_per_token.data<int>(),
      seq_lens_encoder.data<int>(),
      seq_lens_decoder.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      token_num,
      num_heads,
      kv_num_heads,
      max_seq_len,
      rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
      head_dim,
      rope_3d,
      rms_norm_eps,
      stream);
  if (use_neox_rotary_style) {
    gqa_rotary_qk_split_variable_qwen3<data_t>(qkv_out.data<data_t>(),
                                               q.data<data_t>(),
                                               k.data<data_t>(),
                                               v.data<data_t>(),
                                               qkv.data<data_t>(),
                                               rotary_embs.data<float>(),
                                               batch_id_per_token.data<int>(),
                                               seq_lens_encoder.data<int>(),
                                               seq_lens_decoder.data<int>(),
                                               cu_seqlens_q.data<int>(),
                                               cu_seqlens_k.data<int>(),
                                               token_num,
                                               num_heads,
                                               kv_num_heads,
                                               max_seq_len,
                                               head_dim,
                                               stream);
  } else {
    gqa_rotary_qk_split_variable<data_t>(
        qkv_out.data<data_t>(),
        q.data<data_t>(),
        k.data<data_t>(),
        v.data<data_t>(),
        qkv.data<data_t>(),
        rotary_embs.data<float>(),
        q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
        k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
        batch_id_per_token.data<int>(),
        seq_lens_encoder.data<int>(),
        seq_lens_decoder.data<int>(),
        cu_seqlens_q.data<int>(),
        cu_seqlens_k.data<int>(),
        token_num,
        num_heads,
        kv_num_heads,
        max_seq_len,
        rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
        head_dim,
        rope_3d,
        rms_norm_eps,
        stream);
  }

  if (token_num < kv_token_num) {
    AppendCacheKV<data_t, 128, 64>(key_cache,
@@ -1347,6 +1386,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
    .Attrs({"kv_token_num: int",
            "max_seq_len: int",
            "rms_norm_eps: float",
            "use_neox_rotary_style: bool",
            "cache_quant_type: std::string",
            "rope_3d: bool"})
    .SetKernelFn(PD_KERNEL(GQARopeWriteCacheKernel));

custom_ops/gpu_ops/append_attn/qwen3_rope.h (new file, 167 lines)
@@ -0,0 +1,167 @@
#include "encoder_write_cache_with_rope_impl.cuh"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "remote_cache_kv_ipc.h"

template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
    const T *qkv,
    const float *cos_emb,
    const float *sin_emb,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *seq_lens_encoder,
    const int *seq_lens_decoder,
    const int *cu_seqlens_k,
    T *qkv_out,
    T *q,
    T *k,
    T *v,
    const int64_t elem_cnt,
    const int q_num_head,
    const int kv_num_head,
    const int max_model_len,
    const int head_dim) {
  using LoadT = AlignedVector<T, VecSize>;
  using LoadEmbT = AlignedVector<float, VecSize>;
  LoadEmbT cos_emb_vec;
  LoadEmbT sin_emb_vec;

  const int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  const int offset = (q_num_head + kv_num_head * 2) * (head_dim / 2);
  const int64_t loop_times = elem_cnt / 2;

  for (int64_t linear_index = global_thread_idx * VecSize;
       linear_index < loop_times;
       linear_index += gridDim.x * blockDim.x * VecSize) {
    const int token_idx = linear_index / offset;

    const int ori_bi = batch_id_per_token[token_idx];  // which batch this token belongs to

    int cache_kv_len = seq_lens_decoder[ori_bi];
    // Strictly not needed here, but required to work around an FA3 bug.
    if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;

    const int bias = linear_index % offset;
    const int hi = bias / (head_dim / 2);
    const int h_bias = bias % (head_dim / 2);
    // this thread handles element h_bias of head hi for token token_idx

    const int ori_seq_id =
        (token_idx - cu_seqlens_q[ori_bi]) +
        cache_kv_len;  // position inside the current sequence (valid when
                       // sequences are packed into a single batch)

    const int half_headdim = head_dim / 2;
    const int64_t emb_idx = ori_seq_id * head_dim + h_bias;  // index into the rotary table

    const int64_t read_idx =
        token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
        h_bias;

    LoadT src_vec0;
    LoadT src_vec1;

    Load<T, VecSize>(&qkv[read_idx], &src_vec0);
    Load<T, VecSize>(&qkv[read_idx + 64], &src_vec1);

    const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
    int64_t base_split_idx;
    T *out_p = nullptr;
    if (hi < q_num_head) {
      base_split_idx =
          token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
      out_p = q;
    } else if (hi < q_num_head + kv_num_head) {
      base_split_idx = kv_write_idx * kv_num_head * head_dim +
                       (hi - q_num_head) * head_dim + h_bias;
      out_p = k;
    } else {
      out_p = v;
      base_split_idx = kv_write_idx * kv_num_head * head_dim +
                       (hi - q_num_head - kv_num_head) * head_dim + h_bias;
    }

    // TODO: verify that this is correct
    int64_t new_emb_idx = emb_idx;

    if (hi < q_num_head + kv_num_head) {
      Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
      for (int i = 0; i < VecSize; i++) {
        float input_left = static_cast<float>(src_vec0[i]);
        float input_right = static_cast<float>(src_vec1[i]);

        const float cos_tmp = cos_emb_vec[i];
        const float sin_tmp = sin_emb_vec[i];
        src_vec0[i] =
            static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
        src_vec1[i] =
            static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
      }
    }
    Store<T, VecSize>(src_vec0, &qkv_out[read_idx]);
    Store<T, VecSize>(src_vec0, &out_p[base_split_idx]);
    Store<T, VecSize>(src_vec1, &qkv_out[read_idx + 64]);
    Store<T, VecSize>(src_vec1, &out_p[base_split_idx + 64]);
  }
}

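For intuition, the per-element rotation the kernel applies to each query and key head (value heads are copied through unrotated) can be written in a few lines of NumPy. This is a reference sketch, not code from the PR; it mirrors the `src_vec0` / `src_vec1` updates above.

```python
import numpy as np

def neox_rotate_head(x: np.ndarray, cos: np.ndarray, sin: np.ndarray) -> np.ndarray:
    """Rotate one head: x is [head_dim], cos/sin are [head_dim // 2] for one position."""
    half = x.shape[-1] // 2
    left, right = x[:half], x[half:]
    out = np.empty_like(x)
    out[:half] = left * cos - right * sin   # mirrors the src_vec0 update
    out[half:] = right * cos + left * sin   # mirrors the src_vec1 update
    return out

head_dim = 128
x = np.random.randn(head_dim).astype(np.float32)
angles = np.random.rand(head_dim // 2).astype(np.float32)
rotated = neox_rotate_head(x, np.cos(angles), np.sin(angles))
assert rotated.shape == (head_dim,)
```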
template <typename T>
void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
                                        T *q,
                                        T *k,
                                        T *v,
                                        const T *qkv_input,
                                        const float *rotary_emb,
                                        const int *batch_id_per_token,
                                        const int *seq_lens_encoder,
                                        const int *seq_lens_decoder,
                                        const int *cu_seqlens_q,
                                        const int *cu_seqlens_k,
                                        const int token_num,
                                        const int num_heads,
                                        const int kv_num_heads,
                                        const int max_model_len,
                                        const int head_dim,
                                        const cudaStream_t &stream) {
  assert(head_dim == 128 && "head_dim must be 128");

  int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;

  constexpr int HEAD_DIM = 128;
  constexpr int PackSize = 8;
  const int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  int grid_size = 1;
  GetNumBlocks<128>(pack_num, &grid_size);
  dim3 block_size(128);

  const float *cos_emb = rotary_emb;
  const float *sin_emb = rotary_emb + max_model_len * head_dim;
  launchWithPdlWhenEnabled(
      GQAVariableLengthRotarySplitKernel_Qwen3<T, PackSize>,
      grid_size,
      block_size,
      0,
      stream,
      qkv_input,
      cos_emb,
      sin_emb,
      batch_id_per_token,
      cu_seqlens_q,
      seq_lens_encoder,
      seq_lens_decoder,
      cu_seqlens_k,
      qkv_out,
      q,
      k,
      v,
      elem_nums,
      num_heads,
      kv_num_heads,
      max_model_len,
      head_dim);
}

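To get a rough sense of the launch size set up above: each thread processes PackSize = 8 elements and a block holds 128 threads, so a typical GQA shape needs on the order of thousands of blocks (GetNumBlocks may additionally cap the value). The numbers below are purely illustrative, not taken from the PR.

```python
# Illustrative launch-size arithmetic for the wrapper above (made-up shape).
token_num, num_heads, kv_num_heads, head_dim = 1024, 64, 8, 128
pack_size, block_size = 8, 128

elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim  # 10,485,760 elements
pack_num = elem_nums // pack_size                                   # 1,310,720 vectorized work items
grid_size = (pack_num + block_size - 1) // block_size               # 10,240 blocks before any cap
print(elem_nums, pack_num, grid_size)
```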
@@ -190,6 +190,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const int kv_token_num,
    const int max_seq_len,
    const float rms_norm_eps,
    const bool use_neox_rotary_style,
    const std::string& cache_quant_type,
    const bool rope_3d);

@@ -102,7 +102,6 @@ class FlashAttentionBackend(AttentionBackend):
        FlashAttentionBackend __init__
        """
        super().__init__()
        self.attention_metadata: FlashAttentionMetadata = None
        self.max_seq_len = fd_config.model_config.max_model_len
        self.causal = getattr(fd_config.model_config, "causal", True)

@@ -150,10 +149,6 @@ class FlashAttentionBackend(AttentionBackend):
            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
        )

    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata

    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
@@ -233,7 +228,7 @@ class FlashAttentionBackend(AttentionBackend):
        metadata.max_len_tensor_cpu_decoder = paddle.clone(forward_meta.max_len_tensor_cpu)
        metadata.max_len_tensor_cpu_decoder[1] = 0

        self.attention_metadata = metadata
        forward_meta.attention_metadata = metadata

    def forward_mixed(
        self,
@@ -246,7 +241,7 @@ class FlashAttentionBackend(AttentionBackend):
        layer: Attention,
        forward_meta: ForwardMeta,
    ):
        metadata = self.attention_metadata
        metadata = forward_meta.attention_metadata

        if self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
@@ -287,6 +282,7 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.kv_token_num_cpu[0].item(),
                self.max_seq_len,
                getattr(layer, "rms_norm_eps", 1e-6),
                layer.use_neox_rotary_style,
                getattr(layer, "cache_quant_type_str", "none"),
                self.rope_3d,
            )

@@ -57,13 +57,7 @@ class FlashMaskAttentionMetadata(AttentionMetadata):
    FlashAttentionMetadata
    """

    rotary_embs: Optional[paddle.Tensor] = None
    block_tables: Optional[paddle.Tensor] = None

    cu_seqlens_q: paddle.Tensor = None
    cu_seqlens_k: paddle.Tensor = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0

    pre_cache_batch_ids = None
    pre_cache_tile_ids_per_batch = None
@@ -173,9 +167,6 @@ class FlashMaskAttentionBackend(AttentionBackend):

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        metadata = FlashMaskAttentionMetadata()
        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.block_tables = forward_meta.block_tables
        get_block_shape_and_split_kv_block(
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
@@ -265,14 +256,14 @@ class FlashMaskAttentionBackend(AttentionBackend):
            qkv,
            forward_meta.caches[2 * layer.layer_id],
            forward_meta.caches[2 * layer.layer_id + 1],
            metadata.cu_seqlens_q,
            forward_meta.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.rotary_embs,
            forward_meta.seq_lens_this_time,
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.batch_id_per_token,
            metadata.block_tables,
            forward_meta.block_tables,
            forward_meta.kv_batch_ids,
            forward_meta.kv_tile_ids_per_batch,
            forward_meta.kv_num_blocks_x_cpu,
@@ -291,6 +282,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
            metadata.kv_token_num_cpu[0].item(),
            self.max_seq_len,
            getattr(layer, "rms_norm_eps", 1e-6),
            layer.use_neox_rotary_style,
            getattr(layer, "cache_quant_type_str", "none"),
            self.rope_3d,
        )
@@ -299,7 +291,7 @@ class FlashMaskAttentionBackend(AttentionBackend):
            q,
            k,
            v,
            metadata.cu_seqlens_q,
            forward_meta.cu_seqlens_q,
            metadata.cu_seqlens_k,
            forward_meta.seq_lens_encoder,
            res_encoder,

@@ -51,6 +51,7 @@ def gqa_rope_write_cache(
    kv_token_num: int = 1,
    max_seq_len: int = 0,
    rms_norm_eps: float = 1e-6,
    use_neox_rotary_style: bool = False,
    cache_quant_type: str = "none",
    rope_3d: bool = False,
):
@@ -87,6 +88,7 @@ def gqa_rope_write_cache(
        kv_token_num,
        max_seq_len,
        rms_norm_eps,
        use_neox_rotary_style,
        cache_quant_type,
        rope_3d,
    )
