fa3_rope (#4190)
@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     const int q_num_head,
     const int kv_num_head,
     const int seq_len,
-    const int last_dim) {
+    const int last_dim,
+    const bool rope_3d) {
   using LoadT = AlignedVector<T, VecSize>;
   constexpr int HalfVecSize = VecSize / 2;
   using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
   const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;

   const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
+  int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
   const int64_t base_idx =
       token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
       h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
     Load<T, VecSize>(&qkv[base_idx], &src_vec);
     // do rope
     if (hi < q_num_head + kv_num_head) {
-      Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
-      Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
+      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
       for (int i = 0; i < HalfVecSize; i++) {
         const float input_left = static_cast<float>(src_vec[2 * i]);
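
A minimal Python sketch of the index arithmetic introduced above (names mirror the kernel; the per-batch layout of the cos/sin tables is inferred from the diff, not taken from FastDeploy documentation): with rope_3d, each batch entry ori_bi reads from its own cos/sin table, so the flat offset is shifted by ori_bi * last_dim * seq_len.

# Sketch only, not the kernel itself: reproduces new_emb_idx from the hunk above.
def rope_emb_index(ori_seq_id: int, h_bias: int, ori_bi: int,
                   seq_len: int, last_dim: int, rope_3d: bool) -> int:
    half_lastdim = last_dim // 2
    emb_idx = ori_seq_id * half_lastdim + h_bias // 2
    # rope_3d: per-batch tables, assumed stride of last_dim * seq_len per batch.
    return emb_idx + ori_bi * last_dim * seq_len if rope_3d else emb_idx

# Example: token 5 of batch 2, element offset 8, seq_len 1024, head dim 128.
assert rope_emb_index(5, 8, 2, 1024, 128, rope_3d=False) == 5 * 64 + 4
assert rope_emb_index(5, 8, 2, 1024, 128, rope_3d=True) == 5 * 64 + 4 + 2 * 128 * 1024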
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
     const int seq_len,
     const int input_output_len,
     const int dim_head,
+    const bool rope_3d,
     const cudaStream_t &stream) {
   int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
   constexpr int PackSize = 16 / sizeof(T);
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
       num_heads,
       kv_num_heads,
       seq_len,
-      dim_head);
+      dim_head,
+      rope_3d);
 }

 template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const int kv_token_num,
     const int max_seq_len,
-    const std::string& cache_quant_type) {
+    const std::string& cache_quant_type,
+    const bool rope_3d) {
   typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
   typedef typename traits_::DataType DataType_;
   typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
       num_heads,
       kv_num_heads,
       max_seq_len,
-      rotary_embs.dims()[2],
+      rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
       head_dim,
+      rope_3d,
       stream);

   if (token_num < kv_token_num) {
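
The only change at this call site is which axis of rotary_embs supplies the table extent. A plausible reading (an assumption inferred from this hunk, not stated in the diff) is that rope_3d adds a leading batch axis to rotary_embs, shifting every existing axis up by one, so the extent formerly at dims()[2] now sits at dims()[3].

# Sketch of that axis shift under the assumed layouts:
#   shared table  -> (2, max_pos, dim)        -> extent read at axis 2
#   per-batch 3-D -> (batch, 2, max_pos, dim) -> same extent read at axis 3
def rotary_table_extent(shape, rope_3d: bool) -> int:
    return shape[3] if rope_3d else shape[2]

assert rotary_table_extent((2, 4096, 128), rope_3d=False) == 128
assert rotary_table_extent((8, 2, 4096, 128), rope_3d=True) == 128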
@@ -107,7 +107,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor> &cache_v_zp,
     const paddle::optional<paddle::Tensor> &kv_signal_data,
     const int kv_token_num, const int max_seq_len,
-    const std::string &cache_quant_type);
+    const std::string &cache_quant_type,
+    const bool rope_3d);

 std::vector<paddle::Tensor>
 PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,
@@ -283,6 +283,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_token_num_cpu[0].item(),
             self.max_seq_len,
             getattr(layer, "cache_quant_type_str", "none"),
+            self.rope_3d,
         )

         res = self.flash_attn_func(
@@ -49,6 +49,7 @@ def gqa_rope_write_cache(
     kv_token_num: int = 1,
     max_seq_len: int = 0,
     cache_quant_type: str = "none",
+    rope_3d: bool = False,
 ):
     if current_platform.is_cuda():
         from fastdeploy.model_executor.ops.gpu import gqa_rope_write_cache
@@ -81,6 +82,7 @@ def gqa_rope_write_cache(
             kv_token_num,
             max_seq_len,
             cache_quant_type,
+            rope_3d,
         )
         return q, k, v, qkv_
     else:
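
Since rope_3d defaults to False in the Python wrapper, existing call sites are unaffected; only callers that opt in pass the flag through to the CUDA op. A minimal stub (the real signature takes many tensor arguments, elided here) showing that keyword plumbing:

# Hypothetical stand-in for the wrapper; it only demonstrates the new keyword.
def gqa_rope_write_cache_stub(*tensors, kv_token_num=1, max_seq_len=0,
                              cache_quant_type="none", rope_3d=False):
    # The real wrapper forwards these values to the gqa_rope_write_cache
    # CUDA op; this stub just surfaces the flag added by the commit.
    return rope_3d

assert gqa_rope_write_cache_stub(max_seq_len=8192) is False        # legacy callers
assert gqa_rope_write_cache_stub(max_seq_len=8192, rope_3d=True)   # 3-D RoPE opt-in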