Mirror of https://github.com/PaddlePaddle/FastDeploy.git
support fa3 rope3d (#3622)
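The change threads a new rope_3d flag through the whole stack: the Python attention backend and the gqa_rope_write_cache wrapper pass it to the custom op, GQARopeWriteCacheKernel and its C++ declaration gain the parameter, the host launcher gqa_rotary_qk_split_variable forwards it into the kernel launch, and the CUDA kernel uses it to offset the cos/sin rotary-embedding index per batch.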
@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int q_num_head,
     const int kv_num_head,
     const int seq_len,
-    const int last_dim) {
+    const int last_dim,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;

     const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+    int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
     const int64_t base_idx =
         token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
         h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     Load<T, VecSize>(&qkv[base_idx], &src_vec);
     // do rope
     if (hi < q_num_head + kv_num_head) {
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {
         const float input_left = static_cast<float>(src_vec[2 * i]);
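The two hunks above are the heart of the kernel change: with rope_3d enabled, each batch reads its cos/sin values from its own slab of the rotary table, offset by ori_bi * last_dim * seq_len, instead of a single shared table. A minimal Python sketch of that index arithmetic (the names and the per-batch stride are taken from the kernel; the example values are assumed purely for illustration):

def emb_index(ori_seq_id, h_bias, ori_bi, seq_len, last_dim, rope_3d):
    # Reproduces the kernel's cos/sin lookup index.
    half_lastdim = last_dim // 2
    emb_idx = ori_seq_id * half_lastdim + h_bias // 2
    # With rope_3d, batch ori_bi indexes into its own slab; the per-batch
    # stride (last_dim * seq_len) is copied verbatim from the kernel.
    return emb_idx + ori_bi * last_dim * seq_len if rope_3d else emb_idx

# Example: token at position 5 of batch 1, head offset 8, seq_len 1024, head dim 128.
print(emb_index(5, 8, 1, 1024, 128, rope_3d=False))  # 324
print(emb_index(5, 8, 1, 1024, 128, rope_3d=True))   # 324 + 131072 = 131396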
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
     const int seq_len,
     const int input_output_len,
     const int dim_head,
+    const bool rope_3d,
     const cudaStream_t &stream) {
   int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
   constexpr int PackSize = 16 / sizeof(T);
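For context on the launcher above: PackSize comes from 16-byte vectorized loads, so the kernel's VecSize is 16 / sizeof(T) elements and the rotary loop covers HalfVecSize = VecSize / 2 (cos, sin) pairs per load. A quick sketch of that arithmetic for a 2-byte element type such as bfloat16 (the element size here is only an example):

elem_bytes = 2                   # e.g. bfloat16
pack_size = 16 // elem_bytes     # elements per 16-byte vector load -> 8
half_vec_size = pack_size // 2   # rotary (cos, sin) pairs per load -> 4
print(pack_size, half_vec_size)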
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
       num_heads,
       kv_num_heads,
       seq_len,
-      dim_head);
+      dim_head,
+      rope_3d);
 }

 template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const int kv_token_num,
     const int max_seq_len,
-    const std::string& cache_quant_type) {
+    const std::string& cache_quant_type,
+    const bool rope_3d) {
   typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
       num_heads,
       kv_num_heads,
       max_seq_len,
-      rotary_embs.dims()[2],
+      rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
       head_dim,
+      rope_3d,
       stream);

   if (token_num < kv_token_num) {
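Note the seq_len argument handed to gqa_rotary_qk_split_variable: with rope_3d it is read from rotary_embs.dims()[3] rather than dims()[2], which suggests the 3-D rotary table carries an extra leading dimension (presumably per batch); the exact layout of rotary_embs is not visible in this diff.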
@@ -111,7 +111,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor> &cache_v_zp,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
     const int kv_token_num, const int max_seq_len,
-    const std::string &cache_quant_type);
+    const std::string &cache_quant_type,
+    const bool rope_3d);

 std::vector<paddle::Tensor>
 PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -311,6 +311,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.kv_token_num_cpu[0].item(),
                 self.max_seq_len,
                 getattr(layer, "cache_quant_type_str", "none"),
+                self.rope_3d,
             )

             res_encoder = self.flash_attn_func(
@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
     kv_token_num: int = 1,
     max_seq_len: int = 0,
     cache_quant_type: str = "none",
+    rope_3d: bool = False,
 ):
     if current_platform.is_cuda():
         from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
             kv_token_num,
             max_seq_len,
             cache_quant_type,
+            rope_3d,
         )
         return q, k, v, qkv_
     else:
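On the Python side the new keyword argument defaults to rope_3d=False, so existing callers of gqa_rope_write_cache keep their behaviour; FlashAttentionBackend simply forwards its own self.rope_3d attribute when invoking the op.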